// blob: 10e15cc31655b44586e7c39bc004c354338361ca [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include <algorithm>
#include <functional>
#include <iostream>
#include <iterator>
#include <memory>
#include <ostream>
#include <set>
#include <string>
#include <utility>
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/ascii.h"
#include "absl/strings/escaping.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_op_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
#include "tensorflow/compiler/xla/service/mapped_ptr_container_sorter.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/human_readable_json.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
using absl::CEscape;
using absl::StrAppend;
using absl::StrCat;
using absl::StrJoin;
// Adds `derived_instruction` to the computation that owns this instruction
// (via parent()->AddInstruction) and runs SetupDerivedInstruction on the
// newly added instruction. Returns the raw pointer handed back by the parent
// computation.
//
// If the new instruction carried no sharding of its own before the setup
// step and it is a kReshape or kTranspose, its sharding is cleared again
// afterwards.
// NOTE(review): presumably SetupDerivedInstruction copies this instruction's
// sharding onto the derived one, which would not be valid across a
// reshape/transpose -- confirm against SetupDerivedInstruction.
HloInstruction* HloInstruction::AddInstruction(
    std::unique_ptr<HloInstruction> derived_instruction) {
  HloInstruction* const added =
      parent()->AddInstruction(std::move(derived_instruction));

  // Must be sampled before SetupDerivedInstruction runs: the decision below
  // depends on the sharding state the instruction arrived with.
  const bool arrived_with_sharding = added->has_sharding();
  SetupDerivedInstruction(added);

  const bool is_reshape_or_transpose =
      added->opcode() == HloOpcode::kReshape ||
      added->opcode() == HloOpcode::kTranspose;
  if (is_reshape_or_transpose && !arrived_with_sharding) {
    added->clear_sharding();
  }
  return added;
}
/* static */
StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
const HloInstructionProto& proto,
const absl::flat_hash_map<int64_t, HloInstruction*>& instruction_map,
const absl::flat_hash_map<int64_t, HloComputation*>& computation_map,
bool prohibit_empty_literal) {
TF_RET_CHECK(!proto.opcode().empty());
HloOpcode opcode;
auto opcode_or = StringToHloOpcode(proto.opcode());
std::optional<ComparisonDirection> comparison_direction;
if (opcode_or.ok()) {
opcode = std::move(opcode_or).value();
} else {
// Unknown opcode. Try auto-upgrading deprecated "less-than",
// "greater-than", etc opcodes, which are now rolled into the kCompare
// opcode.
if (proto.opcode() == "equal-to") {
comparison_direction = ComparisonDirection::kEq;
} else if (proto.opcode() == "not-equal-to") {
comparison_direction = ComparisonDirection::kNe;
} else if (proto.opcode() == "greater-than-or-equal-to") {
comparison_direction = ComparisonDirection::kGe;
} else if (proto.opcode() == "greater-than") {
comparison_direction = ComparisonDirection::kGt;
} else if (proto.opcode() == "less-than-or-equal-to") {
comparison_direction = ComparisonDirection::kLe;
} else if (proto.opcode() == "less-than") {
comparison_direction = ComparisonDirection::kLt;
}
if (comparison_direction) {
opcode = HloOpcode::kCompare;
} else {
return InvalidArgument("Unknown opcode: %s", proto.opcode());
}
}
TF_RET_CHECK(proto.has_shape());
std::unique_ptr<HloInstruction> instruction;
const auto operands = [&instruction_map, &proto](int index) {
return instruction_map.at(proto.operand_ids(index));
};
const auto all_operands = [&instruction_map, &proto]() {
std::vector<HloInstruction*> result(proto.operand_ids_size());
std::transform(proto.operand_ids().begin(), proto.operand_ids().end(),
result.begin(), [&instruction_map](int64_t operand_id) {
return instruction_map.at(operand_id);
});
return result;
};
const auto computations = [&computation_map, &proto](int index) {
return computation_map.at(proto.called_computation_ids(index));
};
const auto all_computations = [&computation_map, &proto]() {
std::vector<HloComputation*> result(proto.called_computation_ids_size());
std::transform(proto.called_computation_ids().begin(),
proto.called_computation_ids().end(), result.begin(),
[&computation_map](int64_t computation_id) {
return computation_map.at(computation_id);
});
return result;
};
TF_RET_CHECK(
absl::c_all_of(proto.operand_ids(),
[&](int64_t id) { return instruction_map.contains(id); }))
<< proto.name() << " instruction contains invalid operand id(s)";
TF_RET_CHECK(
absl::c_all_of(proto.called_computation_ids(),
[&](int64_t id) { return computation_map.contains(id); }))
<< proto.name() << " instruction references invalid computation id(s)";
Shape shape(proto.shape());
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
std::optional<int> arity = HloOpcodeArity(opcode);
if (arity) {
TF_RET_CHECK(proto.operand_ids_size() == *arity)
<< proto.opcode() << " instruction should have " << *arity
<< " operands but sees " << proto.operand_ids_size();
}
switch (opcode) {
// Ops migrated to subclasses.
case HloOpcode::kBatchNormTraining:
instruction =
CreateBatchNormTraining(shape, operands(0), operands(1), operands(2),
proto.epsilon(), proto.feature_index());
break;
case HloOpcode::kBatchNormInference:
instruction = CreateBatchNormInference(
shape, operands(0), operands(1), operands(2), operands(3),
operands(4), proto.epsilon(), proto.feature_index());
break;
case HloOpcode::kBatchNormGrad:
instruction = CreateBatchNormGrad(shape, operands(0), operands(1),
operands(2), operands(3), operands(4),
proto.epsilon(), proto.feature_index());
break;
case HloOpcode::kFft: {
std::vector<int64_t> fft_length(proto.fft_length().begin(),
proto.fft_length().end());
instruction = CreateFft(shape, operands(0), proto.fft_type(),
absl::Span<const int64_t>(fft_length));
break;
}
case HloOpcode::kAsyncStart: {
TF_RET_CHECK(proto.called_computation_ids_size() == 1)
<< "Async start instruction should have 1 called computation but "
"sees "
<< proto.called_computation_ids_size();
std::optional<int64_t> async_group_id;
if (proto.async_group_id() >= 0) {
async_group_id = proto.async_group_id();
}
instruction = CreateAsyncStart(shape, all_operands(), computations(0),
async_group_id);
break;
}
case HloOpcode::kAsyncUpdate: {
TF_RET_CHECK(proto.called_computation_ids_size() == 1)
<< "Async update instruction should have 1 called computation but "
"sees "
<< proto.called_computation_ids_size();
std::optional<int64_t> async_group_id;
if (proto.async_group_id() >= 0) {
async_group_id = proto.async_group_id();
}
instruction = CreateAsyncUpdate(shape, operands(0), computations(0),
async_group_id);
break;
}
case HloOpcode::kAsyncDone: {
TF_RET_CHECK(proto.called_computation_ids_size() == 1)
<< "Async done instruction should have 1 called computation but sees "
<< proto.called_computation_ids_size();
std::optional<int64_t> async_group_id;
if (proto.async_group_id() >= 0) {
async_group_id = proto.async_group_id();
}
instruction =
CreateAsyncDone(shape, operands(0), computations(0), async_group_id);
break;
}
case HloOpcode::kCopyStart: {
instruction = CreateCopyStart(shape, operands(0),
proto.is_cross_program_prefetch());
break;
}
case HloOpcode::kCompare: {
// Auto-upgraded from deprecated opcode skips the following.
if (!comparison_direction) {
TF_ASSIGN_OR_RETURN(
comparison_direction,
StringToComparisonDirection(proto.comparison_direction()));
}
auto comparison_type_str = proto.comparison_type();
if (!comparison_type_str.empty()) {
// If a comparison type is specified, it *must* be valid.
TF_ASSIGN_OR_RETURN(auto comparison_type,
StringToComparisonType(comparison_type_str));
instruction = CreateCompare(shape, operands(0), operands(1),
*comparison_direction, comparison_type);
} else {
// Allow the comparison type to be omitted.
// The comparison type will be determined by the types of the operands.
instruction = CreateCompare(shape, operands(0), operands(1),
*comparison_direction);
}
break;
}
case HloOpcode::kTriangularSolve: {
instruction = CreateTriangularSolve(shape, operands(0), operands(1),
proto.triangular_solve_options());
break;
}
case HloOpcode::kCholesky: {
instruction =
CreateCholesky(shape, operands(0), proto.cholesky_options());
break;
}
case HloOpcode::kSend:
instruction = CreateSend(operands(0), operands(1), proto.channel_id(),
proto.is_host_transfer());
break;
case HloOpcode::kSendDone:
instruction = CreateSendDone(operands(0), proto.is_host_transfer());
break;
case HloOpcode::kRecv:
instruction = CreateRecv(shape.tuple_shapes(0), operands(0),
proto.channel_id(), proto.is_host_transfer());
break;
case HloOpcode::kRecvDone:
instruction = CreateRecvDone(operands(0), proto.is_host_transfer());
break;
case HloOpcode::kReverse:
instruction =
CreateReverse(shape, operands(0),
std::vector<int64_t>(proto.dimensions().begin(),
proto.dimensions().end()));
break;
case HloOpcode::kConcatenate:
TF_RET_CHECK(proto.dimensions_size() == 1)
<< "Concatenate instruction should have 1 dimension but sees "
<< proto.dimensions_size();
instruction =
CreateConcatenate(shape, all_operands(), proto.dimensions(0));
break;
case HloOpcode::kConditional: {
TF_RET_CHECK(proto.called_computation_ids_size() > 0)
<< "conditional should have at least 1 called computation";
if (operands(0)->shape().element_type() == PRED) {
TF_RET_CHECK(proto.called_computation_ids_size() == 2)
<< "conditional should have exactly 2 called computations but got "
<< proto.called_computation_ids_size();
}
TF_RET_CHECK(proto.operand_ids_size() ==
proto.called_computation_ids_size() + 1)
<< "conditional should have one branch_index operand plus one "
"operand per called computation but got "
<< proto.operand_ids_size() << " operands for "
<< proto.called_computation_ids_size() << " branch computations";
auto cond_operands = all_operands();
instruction =
CreateConditional(shape, cond_operands[0], all_computations(),
absl::MakeSpan(cond_operands).subspan(1));
break;
}
case HloOpcode::kReduce:
TF_RET_CHECK(proto.operand_ids_size() % 2 == 0)
<< "Reduce instruction should have an even number of operands but "
"sees "
<< proto.operand_ids_size();
TF_RET_CHECK(proto.called_computation_ids_size() == 1)
<< "Reduce instruction should have 1 called computation but sees "
<< proto.called_computation_ids_size();
{
const auto reduce_operands = all_operands();
auto inputs = absl::MakeSpan(reduce_operands)
.subspan(0, reduce_operands.size() / 2);
auto init_values =
absl::MakeSpan(reduce_operands)
.subspan(reduce_operands.size() / 2, reduce_operands.size());
instruction =
CreateReduce(shape, inputs, init_values,
std::vector<int64_t>(proto.dimensions().begin(),
proto.dimensions().end()),
computations(0));
}
break;
case HloOpcode::kSort: {
TF_RET_CHECK(proto.operand_ids_size() >= 1)
<< "Sort instruction should have at least 1 operand but has "
<< proto.operand_ids_size();
TF_RET_CHECK(proto.dimensions().size() == 1)
<< "Sort instruction should have 1 dimension";
TF_RET_CHECK(proto.called_computation_ids_size() == 1)
<< "Sort instruction should one called computation but sees "
<< proto.called_computation_ids_size();
auto sort_operands = all_operands();
instruction = CreateSort(shape, proto.dimensions(0), all_operands(),
computations(0), proto.is_stable());
break;
}
case HloOpcode::kTranspose:
instruction =
CreateTranspose(shape, operands(0),
std::vector<int64_t>(proto.dimensions().begin(),
proto.dimensions().end()));
break;
case HloOpcode::kBroadcast:
instruction =
CreateBroadcast(shape, operands(0),
std::vector<int64_t>(proto.dimensions().begin(),
proto.dimensions().end()));
break;
case HloOpcode::kMap:
TF_RET_CHECK(proto.called_computation_ids_size() == 1)
<< "Map instruction should have 1 called computation but sees "
<< proto.called_computation_ids_size();
instruction = CreateMap(shape, all_operands(), computations(0));
break;
case HloOpcode::kSlice: {
std::vector<int64_t> slice_starts, slice_limits, slice_strides;
for (const HloInstructionProto::SliceDimensions& slice_dimensions :
proto.slice_dimensions()) {
slice_starts.push_back(slice_dimensions.start());
slice_limits.push_back(slice_dimensions.limit());
slice_strides.push_back(slice_dimensions.stride());
}
instruction = CreateSlice(shape, operands(0), slice_starts, slice_limits,
slice_strides);
break;
}
case HloOpcode::kConstant: {
// TODO(b/110214922): Revert this to CHECK(proto.has_literal()).
if (proto.has_literal()) {
TF_ASSIGN_OR_RETURN(
auto literal,
Literal::CreateFromProto(proto.literal(), prohibit_empty_literal));
instruction = CreateConstant(std::move(literal));
// Literal's shape may have no/different tiling info.
TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
instruction->shape(), shape))
<< instruction->shape().ToString(true) << " vs "
<< shape.ToString(true);
*instruction->mutable_shape() = shape;
} else {
instruction = std::make_unique<HloConstantInstruction>(shape);
}
break;
}
case HloOpcode::kFusion: {
// In the proto, fused computations are held exclusively within the
// HloInstructionProto and do not appear as an HloComputationProto within
// the HloModuleProto.
TF_RET_CHECK(!proto.fusion_kind().empty());
TF_ASSIGN_OR_RETURN(FusionKind fusion_kind,
StringToFusionKind(proto.fusion_kind()));
// Find the fused computation and set its fusion instruction.
TF_RET_CHECK(proto.called_computation_ids_size() == 1)
<< "Expect 1 called computation for fusion instruction but sees "
<< proto.called_computation_ids_size();
const int64_t fusion_id = proto.called_computation_ids(0);
auto* fused_computation =
tensorflow::gtl::FindPtrOrNull(computation_map, fusion_id);
TF_RET_CHECK(fused_computation != nullptr)
<< "No fusion computation with id " << fusion_id;
instruction =
CreateFusion(shape, fusion_kind, all_operands(), fused_computation);
break;
}
case HloOpcode::kRng:
instruction = CreateRng(shape, proto.distribution(), all_operands());
break;
case HloOpcode::kRngBitGenerator:
instruction =
CreateRngBitGenerator(shape, operands(0), proto.rng_algorithm());
break;
case HloOpcode::kRngGetAndUpdateState:
instruction = CreateRngGetAndUpdateState(shape, proto.delta());
break;
case HloOpcode::kParameter:
instruction =
CreateParameter(proto.parameter_number(), shape, proto.name());
if (!proto.parameter_replication().replicated_at_leaf_buffers().empty()) {
instruction->set_parameter_replicated_at_leaf_buffers(
proto.parameter_replication().replicated_at_leaf_buffers());
}
break;
case HloOpcode::kGetTupleElement:
instruction =
CreateGetTupleElement(shape, operands(0), proto.tuple_index());
break;
case HloOpcode::kReducePrecision:
instruction = CreateReducePrecision(
shape, operands(0), proto.exponent_bits(), proto.mantissa_bits());
break;
case HloOpcode::kInfeed: {
TF_RET_CHECK(shape.IsTuple() &&
(ShapeUtil::TupleElementCount(shape) == 2))
<< "Infeed should have a tuple shape with 2 operands, but has: "
<< shape;
const Shape& data_shape = ShapeUtil::GetTupleElementShape(shape, 0);
instruction =
CreateInfeed(data_shape, operands(0), proto.infeed_config());
} break;
case HloOpcode::kOutfeed: {
Shape outfeed_shape(proto.outfeed_shape());
TF_RETURN_IF_ERROR(
ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape));
instruction = CreateOutfeed(outfeed_shape, operands(0), operands(1),
proto.outfeed_config());
break;
}
case HloOpcode::kAllGather:
case HloOpcode::kAllGatherStart: {
std::optional<int64_t> channel_id;
if (proto.channel_id() > 0) {
channel_id = proto.channel_id();
}
TF_RET_CHECK(proto.dimensions_size() == 1)
<< "AllGather cannot have more than 1 all-gather dimensions";
int64_t all_gather_dimension = proto.dimensions(0);
if (opcode == HloOpcode::kAllGather) {
instruction = CreateAllGather(
shape, all_operands(), all_gather_dimension,
std::vector<ReplicaGroup>(proto.replica_groups().begin(),
proto.replica_groups().end()),
proto.constrain_layout(), channel_id,
proto.use_global_device_ids());
} else {
instruction = CreateAllGatherStart(
shape, all_operands(), all_gather_dimension,
std::vector<ReplicaGroup>(proto.replica_groups().begin(),
proto.replica_groups().end()),
proto.constrain_layout(), channel_id,
proto.use_global_device_ids());
}
break;
}
case HloOpcode::kAllReduce:
case HloOpcode::kAllReduceStart:
case HloOpcode::kReduceScatter: {
TF_RET_CHECK(proto.called_computation_ids_size() == 1)
<< "AllReduce should have 1 called computation but sees "
<< proto.called_computation_ids_size();
TF_RET_CHECK(proto.channel_id() <= 0 || proto.all_reduce_id() <= 0)
<< "AllReduce cannot have both channel_id() and all_reduce_id()";
std::optional<int64_t> channel_id;
if (proto.channel_id() > 0) {
channel_id = proto.channel_id();
}
if (proto.all_reduce_id() > 0) {
channel_id = proto.all_reduce_id();
}
std::vector<ReplicaGroup> replica_groups(proto.replica_groups().begin(),
proto.replica_groups().end());
if (opcode == HloOpcode::kAllReduce) {
instruction =
CreateAllReduce(shape, all_operands(), computations(0),
replica_groups, proto.constrain_layout(),
channel_id, proto.use_global_device_ids());
} else if (opcode == HloOpcode::kReduceScatter) {
TF_RET_CHECK(proto.dimensions_size() == 1)
<< "ReduceScatter cannot have more than 1 scatter dimensions";
int64_t scatter_dimension = proto.dimensions(0);
instruction = CreateReduceScatter(
shape, all_operands(), computations(0), replica_groups,
proto.constrain_layout(), channel_id, proto.use_global_device_ids(),
scatter_dimension);
} else {
instruction =
CreateAllReduceStart(shape, all_operands(), computations(0),
replica_groups, proto.constrain_layout(),
channel_id, proto.use_global_device_ids());
}
break;
}
case HloOpcode::kAllToAll: {
std::optional<int64_t> channel_id;
if (proto.channel_id() > 0) {
channel_id = proto.channel_id();
}
std::optional<int64_t> split_dimension;
if (proto.dimensions_size() > 0) {
TF_RET_CHECK(proto.dimensions_size() == 1)
<< "AllToAll cannot have more than 1 dimension (split dimension)";
TF_RET_CHECK(all_operands().size() == 1)
<< "AllToAll must have a single operand when the split dimension "
"is specified";
split_dimension = proto.dimensions(0);
}
instruction = CreateAllToAll(
shape, all_operands(),
/*replica_groups=*/
std::vector<ReplicaGroup>(proto.replica_groups().begin(),
proto.replica_groups().end()),
/*constrain_layout=*/proto.constrain_layout(),
/*channel_id=*/channel_id, split_dimension);
break;
}
case HloOpcode::kCollectivePermute:
case HloOpcode::kCollectivePermuteStart: {
TF_RET_CHECK(proto.operand_ids().size() == 1 ||
proto.operand_ids().size() == 4);
std::vector<std::pair<int64_t, int64_t>> source_target_pairs(
proto.source_target_pairs_size());
std::optional<int64_t> channel_id;
if (proto.channel_id() > 0) {
channel_id = proto.channel_id();
}
for (int i = 0; i < source_target_pairs.size(); ++i) {
source_target_pairs[i].first = proto.source_target_pairs(i).source();
source_target_pairs[i].second = proto.source_target_pairs(i).target();
}
if (proto.dynamic_slice_sizes_size() == 0) {
if (opcode == HloOpcode::kCollectivePermute) {
instruction = CreateCollectivePermute(
shape, operands(0), source_target_pairs, channel_id);
} else if (opcode == HloOpcode::kCollectivePermuteStart) {
instruction = CreateCollectivePermuteStart(
shape, operands(0), source_target_pairs, channel_id);
} else {
LOG(FATAL) << "Expect CollectivePermute or CollectivePermuteStart, "
<< "but got " << HloOpcodeString(opcode);
}
} else {
std::vector<std::vector<int64_t>> slice_sizes;
HloInstruction* input = operands(0);
HloInstruction* input_start_indices = operands(2);
if (input->shape().IsTuple() &&
input->shape().tuple_shapes_size() > 1) {
slice_sizes.resize(input->shape().tuple_shapes_size());
} else {
slice_sizes.resize(1);
}
int proto_index = 0;
if (input->shape().IsTuple()) {
if (input_start_indices->shape()
.tuple_shapes(0)
.tuple_shapes(0)
.IsArray()) {
slice_sizes.resize(input->shape().tuple_shapes_size());
for (int i = 0; i < input->shape().tuple_shapes_size(); ++i) {
slice_sizes[i].resize(
input->shape().tuple_shapes(i).dimensions_size());
for (int j = 0;
j < input->shape().tuple_shapes(i).dimensions_size(); ++j) {
CHECK_GE(proto.dynamic_slice_sizes_size(), proto_index);
slice_sizes[i][j] = proto.dynamic_slice_sizes(proto_index);
proto_index += 1;
}
}
} else {
slice_sizes.resize(
input->shape().tuple_shapes_size() *
ShapeUtil::TupleElementCount(
input_start_indices->shape().tuple_shapes(0)));
int slice_sizes_count = 0;
for (int i = 0; i < input->shape().tuple_shapes_size(); ++i) {
for (int j = 0;
j < ShapeUtil::TupleElementCount(
input_start_indices->shape().tuple_shapes(i));
++j) {
slice_sizes[slice_sizes_count].resize(
input->shape().tuple_shapes(i).rank());
for (int k = 0; k < input->shape().tuple_shapes(i).rank();
++k) {
CHECK_GE(proto.dynamic_slice_sizes_size(), proto_index);
slice_sizes[slice_sizes_count][k] =
proto.dynamic_slice_sizes(proto_index);
proto_index += 1;
}
slice_sizes_count += 1;
}
}
}
} else {
slice_sizes.resize(
ShapeUtil::TupleElementCount(input_start_indices->shape()));
if (input_start_indices->shape().tuple_shapes(0).IsTuple()) {
for (int i = 0;
i < ShapeUtil::TupleElementCount(input_start_indices->shape());
++i) {
slice_sizes[i].resize(input->shape().dimensions_size());
for (int j = 0; j < input->shape().dimensions_size(); ++j) {
slice_sizes[i][j] = proto.dynamic_slice_sizes(proto_index);
proto_index += 1;
}
}
} else {
slice_sizes.resize(1);
slice_sizes[0].resize(input->shape().dimensions_size());
for (int j = 0; j < input->shape().dimensions_size(); ++j) {
slice_sizes[0][j] = proto.dynamic_slice_sizes(proto_index);
proto_index += 1;
}
}
}
if (opcode == HloOpcode::kCollectivePermute) {
instruction = CreateCollectivePermute(
shape, operands(0), operands(1), operands(2), operands(3),
source_target_pairs, slice_sizes, channel_id);
} else if (opcode == HloOpcode::kCollectivePermuteStart) {
instruction = CreateCollectivePermuteStart(
shape, operands(0), operands(1), operands(2), operands(3),
source_target_pairs, slice_sizes, channel_id);
} else {
LOG(FATAL) << "Expect CollectivePermute or CollectivePermuteStart, "
<< "but got " << HloOpcodeString(opcode);
}
}
break;
}
case HloOpcode::kReplicaId: {
instruction = CreateReplicaId(shape);
break;
}
case HloOpcode::kPartitionId: {
instruction = CreatePartitionId(shape);
break;
}
case HloOpcode::kConvolution: {
TF_RET_CHECK(proto.has_window());
TF_RET_CHECK(proto.has_convolution_dimension_numbers());
PrecisionConfig precision_config = proto.precision_config();
precision_config.mutable_operand_precision()->Resize(
proto.operand_ids_size(), PrecisionConfig::DEFAULT);
instruction = CreateConvolve(
shape, operands(0), operands(1),
std::max<int64_t>(proto.feature_group_count(), 1),
std::max<int64_t>(proto.batch_group_count(), 1), proto.window(),
proto.convolution_dimension_numbers(), precision_config);
break;
}
case HloOpcode::kReduceWindow:
TF_RET_CHECK(proto.operand_ids_size() % 2 == 0)
<< "Reduce window should have an even number of operands but "
"sees "
<< proto.operand_ids_size();
TF_RET_CHECK(proto.called_computation_ids_size() == 1)
<< "ReduceWindow should have 1 called computation but sees "
<< proto.called_computation_ids_size();
{
const auto reduce_operands = all_operands();
auto inputs = absl::MakeSpan(reduce_operands)
.subspan(0, reduce_operands.size() / 2);
auto init_values =
absl::MakeSpan(reduce_operands)
.subspan(reduce_operands.size() / 2, reduce_operands.size());
instruction = CreateReduceWindow(shape, inputs, init_values,
proto.window(), computations(0));
}
break;
case HloOpcode::kSelectAndScatter:
TF_RET_CHECK(proto.called_computation_ids_size() == 2)
<< "SelectAndScatter should have 2 called computations but sees "
<< proto.called_computation_ids_size();
instruction = CreateSelectAndScatter(shape, operands(0), computations(0),
proto.window(), operands(1),
operands(2), computations(1));
break;
case HloOpcode::kCustomCall: {
if (proto.constrain_layout()) {
// A proto RepeatedPtrField cannot be converted to a Span (it is a
// vector of pointers essentially) so create a vector of shapes to pass
// in.
std::vector<Shape> operand_shapes;
const auto& operand_shapes_with_layout =
proto.operand_shapes_with_layout();
operand_shapes.reserve(operand_shapes_with_layout.size());
for (const ShapeProto& shape_proto : operand_shapes_with_layout) {
operand_shapes.emplace_back(shape_proto);
}
instruction =
CreateCustomCall(shape, all_operands(), proto.custom_call_target(),
operand_shapes, proto.backend_config());
} else {
if (proto.called_computation_ids_size() == 1) {
instruction = CreateCustomCall(shape, all_operands(), computations(0),
proto.custom_call_target(),
proto.backend_config());
} else if (proto.called_computation_ids_size() > 1) {
instruction = CreateCustomCall(
shape, all_operands(), all_computations(),
proto.custom_call_target(), proto.backend_config());
} else {
instruction = CreateCustomCall(shape, all_operands(),
proto.custom_call_target(),
proto.backend_config());
}
}
auto custom_call_instr =
Cast<HloCustomCallInstruction>(instruction.get());
if (proto.has_window()) {
custom_call_instr->set_window(proto.window());
}
if (proto.has_literal()) {
TF_ASSIGN_OR_RETURN(
auto literal,
Literal::CreateFromProto(proto.literal(), prohibit_empty_literal));
custom_call_instr->set_literal(std::move(literal));
}
if (proto.has_convolution_dimension_numbers()) {
custom_call_instr->set_convolution_dimension_numbers(
proto.convolution_dimension_numbers());
}
custom_call_instr->set_feature_group_count(std::max(
static_cast<int64_t>(proto.feature_group_count()), int64_t{1}));
custom_call_instr->set_batch_group_count(std::max(
static_cast<int64_t>(proto.batch_group_count()), int64_t{1}));
custom_call_instr->set_custom_call_has_side_effect(
proto.custom_call_has_side_effect());
custom_call_instr->set_padding_type(proto.padding_type());
PrecisionConfig precision_config = proto.precision_config();
precision_config.mutable_operand_precision()->Resize(
proto.operand_ids_size(), PrecisionConfig::DEFAULT);
*custom_call_instr->mutable_precision_config() = precision_config;
std::vector<std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
output_to_operand_aliasing;
for (const auto& aliasing : proto.custom_call_output_operand_aliasing()) {
output_to_operand_aliasing.emplace_back(
ShapeIndex(aliasing.output_shape_index().begin(),
aliasing.output_shape_index().end()),
std::pair<int64_t, ShapeIndex>{
aliasing.operand_index(),
ShapeIndex(aliasing.operand_shape_index().begin(),
aliasing.operand_shape_index().end())});
}
custom_call_instr->set_output_to_operand_aliasing(
std::move(output_to_operand_aliasing));
custom_call_instr->set_custom_call_schedule(proto.custom_call_schedule());
custom_call_instr->set_api_version(proto.custom_call_api_version());
break;
}
case HloOpcode::kPad:
TF_RET_CHECK(proto.has_padding_config());
instruction =
CreatePad(shape, operands(0), operands(1), proto.padding_config());
break;
case HloOpcode::kDynamicSlice: {
std::vector<int64_t> slice_sizes(proto.dynamic_slice_sizes_size());
absl::c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin());
TF_RET_CHECK(proto.operand_ids_size() >= 1)
<< "DynamicSlice instruction should have at least 1 operands but "
"sees "
<< proto.operand_ids_size();
// TODO(b/118437727): Old form, make the check unconditional.
if (proto.operand_ids_size() != 2 || operands(1)->shape().rank() != 1) {
auto expected_operands = 1 + operands(0)->shape().rank();
TF_RET_CHECK(proto.operand_ids_size() == expected_operands)
<< "DynamicSlice instruction should have " << expected_operands
<< " operands, but has " << proto.operand_ids_size();
}
const auto& operand_vector = all_operands();
instruction = CreateDynamicSlice(
shape, operands(0), absl::MakeSpan(operand_vector).subspan(1),
slice_sizes);
break;
}
case HloOpcode::kDynamicUpdateSlice: {
TF_RET_CHECK(proto.operand_ids_size() >= 2)
<< "DynamicUpdateSlice instruction should have at least 2 operands "
"but sees "
<< proto.operand_ids_size();
// TODO(b/118437727): Old form, make the check unconditional.
if (proto.operand_ids_size() != 3 || operands(2)->shape().rank() != 1) {
auto expected_operands = 2 + operands(0)->shape().rank();
TF_RET_CHECK(proto.operand_ids_size() == expected_operands)
<< "DynamicUpdateSlice instruction should have "
<< expected_operands << " operands, but has "
<< proto.operand_ids_size();
}
const auto& operand_vector = all_operands();
instruction =
CreateDynamicUpdateSlice(shape, operands(0), operands(1),
absl::MakeSpan(operand_vector).subspan(2));
break;
}
case HloOpcode::kGather: {
TF_RET_CHECK(proto.has_gather_dimension_numbers())
<< "Gather instruction should have GatherDimensionNumbers set.";
auto gather_dimension_numbers = std::make_unique<GatherDimensionNumbers>(
proto.gather_dimension_numbers());
std::vector<int64_t> gather_slice_sizes;
const auto& slice_sizes = proto.gather_slice_sizes();
gather_slice_sizes.reserve(slice_sizes.size());
for (int64_t bound : slice_sizes) {
gather_slice_sizes.push_back(bound);
}
instruction = CreateGather(shape, operands(0), operands(1),
*gather_dimension_numbers, gather_slice_sizes,
proto.indices_are_sorted());
break;
}
case HloOpcode::kScatter: {
TF_RET_CHECK(proto.has_scatter_dimension_numbers())
<< "Scatter instruction should have ScatterDimensionNumbers set.";
TF_RET_CHECK(proto.called_computation_ids_size() == 1)
<< "Scatter instruction should have 1 called computation but sees "
<< proto.called_computation_ids_size();
auto scatter_dimension_numbers =
std::make_unique<ScatterDimensionNumbers>(
proto.scatter_dimension_numbers());
auto operands = all_operands();
auto operand_span = absl::MakeConstSpan(operands);
auto input_count = operands.size() / 2;
instruction =
CreateScatter(shape, operand_span.first(input_count),
operands[input_count], operand_span.last(input_count),
computations(0), *scatter_dimension_numbers,
proto.indices_are_sorted(), proto.unique_indices());
break;
}
case HloOpcode::kIota:
TF_RET_CHECK(proto.dimensions_size() == 1)
<< "Iota instruction should have 1 dimension but sees "
<< proto.dimensions_size();
instruction = CreateIota(shape, proto.dimensions(0));
break;
case HloOpcode::kDot: {
TF_RET_CHECK(proto.has_dot_dimension_numbers())
<< "Dot instruction should have dot_dimension_numbers.";
PrecisionConfig precision_config = proto.precision_config();
precision_config.mutable_operand_precision()->Resize(
proto.operand_ids_size(), PrecisionConfig::DEFAULT);
instruction = std::make_unique<HloDotInstruction>(
shape, operands(0), operands(1), proto.dot_dimension_numbers(),
precision_config);
break;
}
case HloOpcode::kDomain: {
std::shared_ptr<const HloSharding> entry_hlo_sharding;
std::shared_ptr<const HloSharding> exit_hlo_sharding;
if (proto.has_domain_entry_sharding()) {
TF_ASSIGN_OR_RETURN(
HloSharding sharding,
HloSharding::FromProto(proto.domain_entry_sharding()));
entry_hlo_sharding = std::make_shared<const HloSharding>(sharding);
}
if (proto.has_domain_exit_sharding()) {
TF_ASSIGN_OR_RETURN(
HloSharding sharding,
HloSharding::FromProto(proto.domain_exit_sharding()));
exit_hlo_sharding = std::make_shared<const HloSharding>(sharding);
}
instruction = std::make_unique<HloDomainInstruction>(
shape, operands(0),
std::make_unique<ShardingMetadata>(entry_hlo_sharding),
std::make_unique<ShardingMetadata>(exit_hlo_sharding));
break;
}
case HloOpcode::kGetDimensionSize:
TF_RET_CHECK(proto.dimensions_size() == 1);
instruction =
CreateGetDimensionSize(shape, operands(0), proto.dimensions(0));
break;
case HloOpcode::kSetDimensionSize:
TF_RET_CHECK(proto.dimensions_size() == 1);
instruction = CreateSetDimensionSize(shape, operands(0), operands(1),
proto.dimensions(0));
break;
case HloOpcode::kReshape: {
int64_t inferred_dimension = -1;
if (!proto.dimensions().empty()) {
inferred_dimension = proto.dimensions()[0];
}
TF_RET_CHECK(shape.IsArray() && operands(0)->shape().IsArray() &&
ShapeUtil::ElementsIn(shape) ==
ShapeUtil::ElementsIn(operands(0)->shape()))
<< "shape: " << ShapeUtil::HumanString(shape)
<< " operand: " << ShapeUtil::HumanString(operands(0)->shape());
instruction = CreateReshape(shape, operands(0), inferred_dimension);
break;
}
case HloOpcode::kDynamicReshape: {
TF_RET_CHECK(shape.IsArray() && operands(0)->shape().IsArray() &&
ShapeUtil::ElementsIn(shape) ==
ShapeUtil::ElementsIn(operands(0)->shape()))
<< "shape: " << ShapeUtil::HumanString(shape)
<< " operand: " << ShapeUtil::HumanString(operands(0)->shape());
const auto& operand_vector = all_operands();
instruction = CreateDynamicReshape(
shape, operands(0), absl::MakeSpan(operand_vector).subspan(1));
break;
}
default: {
instruction = absl::WrapUnique(new HloInstruction(opcode, shape));
for (const int64_t operand_id : proto.operand_ids()) {
instruction->AppendOperand(instruction_map.at(operand_id));
}
if (instruction->opcode() != HloOpcode::kFusion) {
if (instruction->opcode() == HloOpcode::kCall) {
TF_RET_CHECK(proto.called_computation_ids_size() == 1)
<< "Call should have 1 called computation but has "
<< proto.called_computation_ids_size();
}
for (const int64_t computation_id : proto.called_computation_ids()) {
instruction->called_computations_.push_back(
computation_map.at(computation_id));
}
}
TF_RET_CHECK(!proto.has_precision_config())
<< instruction->opcode() << proto.DebugString();
TF_RET_CHECK(!proto.has_dot_dimension_numbers()) << instruction->opcode();
break;
}
}
for (const int64_t predecessor_id : proto.control_predecessor_ids()) {
TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id))
<< "No instruction with id " << predecessor_id;
TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id)
->AddControlDependencyTo(instruction.get()));
}
TF_RET_CHECK(!proto.name().empty());
instruction->SetAndSanitizeName(proto.name());
instruction->metadata_ = proto.metadata();
instruction->backend_config_ = proto.backend_config();
instruction->outer_dimension_partitions_.assign(
proto.outer_dimension_partitions().begin(),
proto.outer_dimension_partitions().end());
TF_RET_CHECK(proto.id() >= 0)
<< "Instruction with negative id: " << proto.id();
TF_RET_CHECK(proto.id() <= INT_MAX)
<< "Instruction with id > INT_MAX: " << proto.id();
instruction->unique_id_ = proto.id();
if (proto.has_sharding()) {
TF_ASSIGN_OR_RETURN(const auto& sharding,
HloSharding::FromProto(proto.sharding()));
instruction->set_sharding(sharding);
}
if (proto.has_frontend_attributes()) {
instruction->set_frontend_attributes(proto.frontend_attributes());
}
return std::move(instruction);
}
// Builds an HloParameterInstruction from `parameter_number`, `shape` and
// `name`, and hands it back as an owning pointer to the base class.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateParameter(
    int64_t parameter_number, const Shape& shape, const std::string& name) {
  std::unique_ptr<HloInstruction> parameter =
      std::make_unique<HloParameterInstruction>(parameter_number, shape, name);
  return parameter;
}
// Builds an HloConstantInstruction that takes ownership of `literal` (the
// by-value argument is moved into the new instruction).
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConstant(
    Literal literal) {
  std::unique_ptr<HloInstruction> constant =
      std::make_unique<HloConstantInstruction>(std::move(literal));
  return constant;
}
// Builds an HloIotaInstruction of the given `shape` along `iota_dimension`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateIota(
    const Shape& shape, int64_t iota_dimension) {
  std::unique_ptr<HloInstruction> iota =
      std::make_unique<HloIotaInstruction>(shape, iota_dimension);
  return iota;
}
// Builds an HloGetTupleElementInstruction with an explicitly supplied result
// `shape`, reading element `index` of `operand`.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateGetTupleElement(const Shape& shape,
                                      HloInstruction* operand, int64_t index) {
  std::unique_ptr<HloInstruction> gte =
      std::make_unique<HloGetTupleElementInstruction>(shape, operand, index);
  return gte;
}
// Convenience overload: the result shape is taken from the `index`-th tuple
// element shape of `operand`, then the explicit-shape overload does the work.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateGetTupleElement(HloInstruction* operand, int64_t index) {
  const Shape& element_shape = operand->shape().tuple_shapes(index);
  return CreateGetTupleElement(element_shape, operand, index);
}
// Builds an HloRngInstruction for `distribution`; `parameters` are forwarded
// unchanged to the instruction's constructor.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRng(
    const Shape& shape, RandomDistribution distribution,
    absl::Span<HloInstruction* const> parameters) {
  std::unique_ptr<HloInstruction> rng =
      std::make_unique<HloRngInstruction>(shape, distribution, parameters);
  return rng;
}
// Builds an HloRngGetAndUpdateStateInstruction from `shape` and `delta`.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateRngGetAndUpdateState(const Shape& shape, int64_t delta) {
  std::unique_ptr<HloInstruction> rng_state =
      std::make_unique<HloRngGetAndUpdateStateInstruction>(shape, delta);
  return rng_state;
}
// Builds an HloRngBitGeneratorInstruction from `shape`, the `state` operand
// and the selected `algorithm`.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state,
                                      RandomAlgorithm algorithm) {
  std::unique_ptr<HloInstruction> bit_generator =
      std::make_unique<HloRngBitGeneratorInstruction>(shape, state, algorithm);
  return bit_generator;
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateNary(
const Shape& shape, HloOpcode opcode,
absl::Span<HloInstruction* const> operands) {
if (opcode == HloOpcode::kCopy) {
// It is impossible to copy an opaque shape, we don't know how big it is.
CHECK(!shape.IsOpaque());
}
auto instruction = absl::WrapUnique(new HloInstruction(opcode, shape));
for (auto operand : operands) {
instruction->AppendOperand(operand);
}
return instruction;
}
// Creates a one-operand instruction via CreateNary. Only opcodes of unary
// instructions with no auxiliary fields are accepted; any other opcode is a
// fatal error.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateUnary(
    const Shape& shape, HloOpcode opcode, HloInstruction* operand) {
  // Allow-list of opcodes supported by CreateUnary.
  const bool is_plain_unary = [opcode] {
    switch (opcode) {
      case HloOpcode::kAbs:
      case HloOpcode::kAllGatherDone:
      case HloOpcode::kAllReduceDone:
      case HloOpcode::kBitcast:
      case HloOpcode::kCbrt:
      case HloOpcode::kCeil:
      case HloOpcode::kClz:
      case HloOpcode::kCollectivePermuteDone:
      case HloOpcode::kCopy:
      case HloOpcode::kCopyDone:
      case HloOpcode::kCos:
      case HloOpcode::kExp:
      case HloOpcode::kExpm1:
      case HloOpcode::kFloor:
      case HloOpcode::kImag:
      case HloOpcode::kIsFinite:
      case HloOpcode::kLog:
      case HloOpcode::kLog1p:
      case HloOpcode::kLogistic:
      case HloOpcode::kNegate:
      case HloOpcode::kNot:
      case HloOpcode::kOptimizationBarrier:
      case HloOpcode::kPopulationCount:
      case HloOpcode::kReal:
      case HloOpcode::kRoundNearestAfz:
      case HloOpcode::kRoundNearestEven:
      case HloOpcode::kRsqrt:
      case HloOpcode::kSign:
      case HloOpcode::kSin:
      case HloOpcode::kSqrt:
      case HloOpcode::kTanh:
        return true;
      default:
        return false;
    }
  }();
  if (!is_plain_unary) {
    LOG(FATAL) << "Invalid unary instruction opcode "
               << HloOpcodeString(opcode);
  }
  return CreateNary(shape, opcode, {operand});
}
// Creates a two-operand instruction via CreateNary. Only opcodes of binary
// instructions with no auxiliary fields are accepted; any other opcode is a
// fatal error.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBinary(
    const Shape& shape, HloOpcode opcode, HloInstruction* lhs,
    HloInstruction* rhs) {
  // Allow-list of opcodes supported by CreateBinary.
  const bool is_plain_binary = [opcode] {
    switch (opcode) {
      case HloOpcode::kAdd:
      case HloOpcode::kAnd:
      case HloOpcode::kAtan2:
      case HloOpcode::kComplex:
      case HloOpcode::kDivide:
      case HloOpcode::kMaximum:
      case HloOpcode::kMinimum:
      case HloOpcode::kMultiply:
      case HloOpcode::kOr:
      case HloOpcode::kPower:
      case HloOpcode::kRemainder:
      case HloOpcode::kShiftLeft:
      case HloOpcode::kShiftRightArithmetic:
      case HloOpcode::kShiftRightLogical:
      case HloOpcode::kSubtract:
      case HloOpcode::kXor:
        return true;
      default:
        return false;
    }
  }();
  if (!is_plain_binary) {
    LOG(FATAL) << "Invalid binary instruction opcode "
               << HloOpcodeString(opcode);
  }
  return CreateNary(shape, opcode, {lhs, rhs});
}
// Creates a three-operand instruction via CreateNary. Only opcodes of ternary
// instructions with no auxiliary fields (kClamp and kSelect) are accepted;
// any other opcode is a fatal error.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTernary(
    const Shape& shape, HloOpcode opcode, HloInstruction* lhs,
    HloInstruction* rhs, HloInstruction* ehs) {
  const bool is_plain_ternary =
      opcode == HloOpcode::kClamp || opcode == HloOpcode::kSelect;
  if (!is_plain_ternary) {
    LOG(FATAL) << "Invalid ternary instruction opcode "
               << HloOpcodeString(opcode);
  }
  return CreateNary(shape, opcode, {lhs, rhs, ehs});
}
// Creates an instruction with an arbitrary number of operands via CreateNary.
// The check below restricts `opcode` to kTuple.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateVariadic(
    const Shape& shape, HloOpcode opcode,
    absl::Span<HloInstruction* const> operands) {
  CHECK_EQ(HloOpcode::kTuple, opcode);
  std::unique_ptr<HloInstruction> variadic =
      CreateNary(shape, opcode, operands);
  return variadic;
}
// Builds an HloMapInstruction over `operands` that refers to
// `map_computation`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateMap(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    HloComputation* map_computation) {
  std::unique_ptr<HloInstruction> map =
      std::make_unique<HloMapInstruction>(shape, operands, map_computation);
  return map;
}
// Builds an HloConvolutionInstruction; every argument is forwarded unchanged
// to the instruction's constructor.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvolve(
    const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
    int64_t feature_group_count, int64_t batch_group_count,
    const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
    const PrecisionConfig& precision_config) {
  std::unique_ptr<HloInstruction> convolution =
      std::make_unique<HloConvolutionInstruction>(
          shape, lhs, rhs, feature_group_count, batch_group_count, window,
          dimension_numbers, precision_config);
  return convolution;
}
// Builds an HloFftInstruction from `operand`, `fft_type` and `fft_length`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFft(
    const Shape& shape, HloInstruction* operand, FftType fft_type,
    absl::Span<const int64_t> fft_length) {
  std::unique_ptr<HloInstruction> fft =
      std::make_unique<HloFftInstruction>(shape, operand, fft_type, fft_length);
  return fft;
}
// Builds an HloAsyncInstruction with opcode kAsyncStart over `operands`,
// referring to `async_computation` and carrying `async_group_id`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAsyncStart(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    HloComputation* async_computation, std::optional<int64_t> async_group_id) {
  std::unique_ptr<HloInstruction> async_start =
      std::make_unique<HloAsyncInstruction>(HloOpcode::kAsyncStart, shape,
                                            operands, async_computation,
                                            async_group_id);
  return async_start;
}
// Builds an HloAsyncInstruction with opcode kAsyncUpdate on a single
// `operand`, referring to `async_computation` and carrying `async_group_id`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAsyncUpdate(
    const Shape& shape, HloInstruction* operand,
    HloComputation* async_computation, std::optional<int64_t> async_group_id) {
  std::unique_ptr<HloInstruction> async_update =
      std::make_unique<HloAsyncInstruction>(HloOpcode::kAsyncUpdate, shape,
                                            operand, async_computation,
                                            async_group_id);
  return async_update;
}
// Builds an HloAsyncInstruction with opcode kAsyncDone on a single `operand`,
// referring to `async_computation` and carrying `async_group_id`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAsyncDone(
    const Shape& shape, HloInstruction* operand,
    HloComputation* async_computation, std::optional<int64_t> async_group_id) {
  std::unique_ptr<HloInstruction> async_done =
      std::make_unique<HloAsyncInstruction>(HloOpcode::kAsyncDone, shape,
                                            operand, async_computation,
                                            async_group_id);
  return async_done;
}
// Builds an HloCopyStartInstruction on `operand`, recording the
// `is_cross_program_prefetch` flag.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCopyStart(
    const Shape& shape, HloInstruction* operand,
    bool is_cross_program_prefetch) {
  std::unique_ptr<HloInstruction> copy_start =
      std::make_unique<HloCopyStartInstruction>(shape, operand,
                                                is_cross_program_prefetch);
  return copy_start;
}
// Builds an HloCompareInstruction comparing `lhs` with `rhs` in the given
// `direction`; the optional comparison `type` is forwarded as-is.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCompare(
    const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
    ComparisonDirection direction, std::optional<Comparison::Type> type) {
  std::unique_ptr<HloInstruction> compare =
      std::make_unique<HloCompareInstruction>(shape, lhs, rhs, direction,
                                              type);
  return compare;
}
// Builds an HloTriangularSolveInstruction from operands `a` and `b` and the
// supplied `options`.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateTriangularSolve(const Shape& shape, HloInstruction* a,
                                      HloInstruction* b,
                                      const TriangularSolveOptions& options) {
  std::unique_ptr<HloInstruction> triangular_solve =
      std::make_unique<HloTriangularSolveInstruction>(shape, a, b, options);
  return triangular_solve;
}
// Builds an HloCholeskyInstruction from operand `a` and the supplied
// `options`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCholesky(
    const Shape& shape, HloInstruction* a, const CholeskyOptions& options) {
  std::unique_ptr<HloInstruction> cholesky =
      std::make_unique<HloCholeskyInstruction>(shape, a, options);
  return cholesky;
}
// Builds an HloDotInstruction of `lhs` and `rhs` using `dimension_numbers`
// and `precision_config`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDot(
    const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
    const DotDimensionNumbers& dimension_numbers,
    const PrecisionConfig& precision_config) {
  std::unique_ptr<HloInstruction> dot = std::make_unique<HloDotInstruction>(
      shape, lhs, rhs, dimension_numbers, precision_config);
  return dot;
}
// Builds an HloReducePrecisionInstruction on `operand` with the requested
// `exponent_bits` and `mantissa_bits`.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateReducePrecision(const Shape& shape,
                                      HloInstruction* operand,
                                      const int exponent_bits,
                                      const int mantissa_bits) {
  std::unique_ptr<HloInstruction> reduce_precision =
      std::make_unique<HloReducePrecisionInstruction>(
          shape, operand, exponent_bits, mantissa_bits);
  return reduce_precision;
}
// Builds an HloAllGatherInstruction with opcode kAllGather; the remaining
// arguments are forwarded unchanged to the instruction's constructor.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllGather(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    int64_t all_gather_dimension, absl::Span<const ReplicaGroup> replica_groups,
    bool constrain_layout, const std::optional<int64_t>& channel_id,
    bool use_global_device_ids) {
  std::unique_ptr<HloInstruction> all_gather =
      std::make_unique<HloAllGatherInstruction>(
          HloOpcode::kAllGather, shape, operands, all_gather_dimension,
          replica_groups, constrain_layout, channel_id, use_global_device_ids);
  return all_gather;
}
// Builds an HloAllGatherInstruction with opcode kAllGatherStart; the
// remaining arguments are forwarded unchanged to the constructor.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateAllGatherStart(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    int64_t all_gather_dimension, absl::Span<const ReplicaGroup> replica_groups,
    bool constrain_layout, const std::optional<int64_t>& channel_id,
    bool use_global_device_ids) {
  std::unique_ptr<HloInstruction> all_gather_start =
      std::make_unique<HloAllGatherInstruction>(
          HloOpcode::kAllGatherStart, shape, operands, all_gather_dimension,
          replica_groups, constrain_layout, channel_id, use_global_device_ids);
  return all_gather_start;
}
// Builds an HloAllReduceInstruction with opcode kAllReduce that refers to
// `reduce_computation`; the remaining arguments are forwarded unchanged.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllReduce(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    HloComputation* reduce_computation,
    absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
    const std::optional<int64_t>& channel_id, bool use_global_device_ids) {
  std::unique_ptr<HloInstruction> all_reduce =
      std::make_unique<HloAllReduceInstruction>(
          HloOpcode::kAllReduce, shape, operands, reduce_computation,
          replica_groups, constrain_layout, channel_id, use_global_device_ids);
  return all_reduce;
}
// Builds an HloReduceScatterInstruction that refers to `reduce_computation`
// and scatters along `scatter_dimension`; the remaining arguments are
// forwarded unchanged.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateReduceScatter(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    HloComputation* reduce_computation,
    absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
    const std::optional<int64_t>& channel_id, bool use_global_device_ids,
    int64_t scatter_dimension) {
  std::unique_ptr<HloInstruction> reduce_scatter =
      std::make_unique<HloReduceScatterInstruction>(
          shape, operands, reduce_computation, replica_groups,
          constrain_layout, channel_id, use_global_device_ids,
          scatter_dimension);
  return reduce_scatter;
}
// Builds an HloAllReduceInstruction with opcode kAllReduceStart that refers
// to `reduce_computation`; the remaining arguments are forwarded unchanged.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateAllReduceStart(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    HloComputation* reduce_computation,
    absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
    const std::optional<int64_t>& channel_id, bool use_global_device_ids) {
  std::unique_ptr<HloInstruction> all_reduce_start =
      std::make_unique<HloAllReduceInstruction>(
          HloOpcode::kAllReduceStart, shape, operands, reduce_computation,
          replica_groups, constrain_layout, channel_id, use_global_device_ids);
  return all_reduce_start;
}
// Builds an HloAllToAllInstruction over `operands`; the optional
// `channel_id` and `split_dimension` are forwarded as-is.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllToAll(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout,
    const std::optional<int64_t>& channel_id,
    const std::optional<int64_t>& split_dimension) {
  std::unique_ptr<HloInstruction> all_to_all =
      std::make_unique<HloAllToAllInstruction>(shape, operands, replica_groups,
                                               constrain_layout, channel_id,
                                               split_dimension);
  return all_to_all;
}
// Single-operand form: builds an HloCollectivePermuteInstruction with opcode
// kCollectivePermute from `operand` and `source_target_pairs`.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateCollectivePermute(
    const Shape& shape, HloInstruction* operand,
    const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
    const std::optional<int64_t>& channel_id) {
  std::unique_ptr<HloInstruction> permute =
      std::make_unique<HloCollectivePermuteInstruction>(
          HloOpcode::kCollectivePermute, shape, operand, source_target_pairs,
          channel_id);
  return permute;
}
// Input/output form: builds an HloCollectivePermuteInstruction with opcode
// kCollectivePermute from separate `input` and `output` operands, their start
// index operands, and `slice_sizes`.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateCollectivePermute(
    const Shape& shape, HloInstruction* input, HloInstruction* output,
    HloInstruction* input_start_indices, HloInstruction* output_start_indices,
    absl::Span<const std::pair<int64_t, int64_t>> source_target_pairs,
    absl::Span<const std::vector<int64_t>> slice_sizes,
    const std::optional<int64_t>& channel_id) {
  std::unique_ptr<HloInstruction> permute =
      std::make_unique<HloCollectivePermuteInstruction>(
          HloOpcode::kCollectivePermute, shape, input, output,
          input_start_indices, output_start_indices, source_target_pairs,
          slice_sizes, channel_id);
  return permute;
}
// Single-operand form: builds an HloCollectivePermuteInstruction with opcode
// kCollectivePermuteStart from `operand` and `source_target_pairs`.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateCollectivePermuteStart(
    const Shape& shape, HloInstruction* operand,
    const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
    const std::optional<int64_t>& channel_id) {
  std::unique_ptr<HloInstruction> permute_start =
      std::make_unique<HloCollectivePermuteInstruction>(
          HloOpcode::kCollectivePermuteStart, shape, operand,
          source_target_pairs, channel_id);
  return permute_start;
}
// Input/output form: builds an HloCollectivePermuteInstruction with opcode
// kCollectivePermuteStart from separate `input` and `output` operands, their
// start index operands, and `slice_sizes`.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateCollectivePermuteStart(
    const Shape& shape, HloInstruction* input, HloInstruction* output,
    HloInstruction* input_start_indices, HloInstruction* output_start_indices,
    absl::Span<const std::pair<int64_t, int64_t>> source_target_pairs,
    absl::Span<const std::vector<int64_t>> slice_sizes,
    const std::optional<int64_t>& channel_id) {
  std::unique_ptr<HloInstruction> permute_start =
      std::make_unique<HloCollectivePermuteInstruction>(
          HloOpcode::kCollectivePermuteStart, shape, input, output,
          input_start_indices, output_start_indices, source_target_pairs,
          slice_sizes, channel_id);
  return permute_start;
}
// Creates a kReplicaId instruction with no operands. `shape` must equal
// u32[] (layout ignored); anything else fails the check below.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReplicaId(
    const Shape& shape) {
  CHECK(Shape::Equal().IgnoreLayout()(shape, ShapeUtil::MakeShape(U32, {})))
      << "HloInstruction replica-id must have a shape of u32[], but "
      << shape.ToString() << " is specified";
  std::unique_ptr<HloInstruction> replica_id(
      new HloInstruction(HloOpcode::kReplicaId, shape));
  return replica_id;
}
// Creates a kPartitionId instruction with no operands. `shape` must equal
// u32[] (layout ignored); anything else fails the check below.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreatePartitionId(
    const Shape& shape) {
  CHECK(Shape::Equal().IgnoreLayout()(shape, ShapeUtil::MakeShape(U32, {})))
      << "HloInstruction partition-id must have a shape of u32[], but "
      << shape.ToString() << " is specified";
  std::unique_ptr<HloInstruction> partition_id(
      new HloInstruction(HloOpcode::kPartitionId, shape));
  return partition_id;
}
// Builds an HloInfeedInstruction from `infeed_shape`, the `token_operand`
// and the `config` string.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed(
    const Shape& infeed_shape, HloInstruction* token_operand,
    const std::string& config) {
  std::unique_ptr<HloInstruction> infeed =
      std::make_unique<HloInfeedInstruction>(infeed_shape, token_operand,
                                             config);
  return infeed;
}
// Builds an HloOutfeedInstruction from `outfeed_shape`, the data `operand`,
// the `token_operand` and `outfeed_config`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateOutfeed(
    const Shape& outfeed_shape, HloInstruction* operand,
    HloInstruction* token_operand, absl::string_view outfeed_config) {
  std::unique_ptr<HloInstruction> outfeed =
      std::make_unique<HloOutfeedInstruction>(outfeed_shape, operand,
                                              token_operand, outfeed_config);
  return outfeed;
}
// Builds an HloSendInstruction from `operand`, `token`, `channel_id` and the
// `is_host_transfer` flag.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend(
    HloInstruction* operand, HloInstruction* token, int64_t channel_id,
    bool is_host_transfer) {
  std::unique_ptr<HloInstruction> send = std::make_unique<HloSendInstruction>(
      operand, token, channel_id, is_host_transfer);
  return send;
}
// Builds an HloSendDoneInstruction. `operand` must be an HloSendInstruction;
// the check below fails if the dynamic cast yields null.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSendDone(
    HloInstruction* operand, bool is_host_transfer) {
  HloSendInstruction* send_operand = DynCast<HloSendInstruction>(operand);
  CHECK(send_operand != nullptr)
      << "SendDone must take the context operand from Send";
  std::unique_ptr<HloInstruction> send_done =
      std::make_unique<HloSendDoneInstruction>(send_operand,
                                               is_host_transfer);
  return send_done;
}
// Builds an HloRecvInstruction from `shape`, `token`, `channel_id` and the
// `is_host_transfer` flag.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecv(
    const Shape& shape, HloInstruction* token, int64_t channel_id,
    bool is_host_transfer) {
  std::unique_ptr<HloInstruction> recv = std::make_unique<HloRecvInstruction>(
      shape, token, channel_id, is_host_transfer);
  return recv;
}
// Builds an HloRecvDoneInstruction. `operand` must be an HloRecvInstruction;
// the check below fails if the dynamic cast yields null.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecvDone(
    HloInstruction* operand, bool is_host_transfer) {
  HloRecvInstruction* recv_operand = DynCast<HloRecvInstruction>(operand);
  CHECK(recv_operand != nullptr)
      << "RecvDone must take the context operand from Recv";
  std::unique_ptr<HloInstruction> recv_done =
      std::make_unique<HloRecvDoneInstruction>(recv_operand,
                                               is_host_transfer);
  return recv_done;
}
// Builds an HloReverseInstruction on `operand` over the listed `dimensions`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReverse(
    const Shape& shape, HloInstruction* operand,
    absl::Span<const int64_t> dimensions) {
  std::unique_ptr<HloInstruction> reverse =
      std::make_unique<HloReverseInstruction>(shape, operand, dimensions);
  return reverse;
}
// Creates a kAfterAll instruction with a token shape and appends every entry
// of `operands` to it. `operands` must be non-empty.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAfterAll(
    absl::Span<HloInstruction* const> operands) {
  CHECK(!operands.empty());
  std::unique_ptr<HloInstruction> instruction(
      new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape()));
  for (HloInstruction* operand : operands) {
    instruction->AppendOperand(operand);
  }
  return instruction;
}
// Creates a token as a kAfterAll instruction with a token shape and no
// operands.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateToken() {
  std::unique_ptr<HloInstruction> token(
      new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape()));
  return token;
}
// Creates a kAddDependency instruction whose shape is copied from
// `data_operand`. Operand 0 is the data operand and operand 1 is the token.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateAddDependency(HloInstruction* data_operand,
                                    HloInstruction* token_operand) {
  std::unique_ptr<HloInstruction> add_dependency(
      new HloInstruction(HloOpcode::kAddDependency, data_operand->shape()));
  add_dependency->AppendOperand(data_operand);
  add_dependency->AppendOperand(token_operand);
  return add_dependency;
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateWhile(
const Shape& shape, HloComputation* condition, HloComputation* body,
HloInstruction* init) {
auto instruction =
absl::WrapUnique(new HloInstruction(HloOpcode::kWhile, shape));
instruction->AppendOperand(init);
// Body comes before condition computation in the vector.
instruction->called_computations_.push_back(body);
instruction->called_computations_.push_back(condition);
return instruction;
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConditional(
const Shape& shape, HloInstruction* pred,
HloInstruction* true_computation_arg, HloComputation* true_computation,
HloInstruction* false_computation_arg, HloComputation* false_computation) {
auto instruction =
absl::WrapUnique(new HloInstruction(HloOpcode::kConditional, shape));
instruction->AppendOperand(pred);
instruction->AppendOperand(true_computation_arg);
instruction->AppendOperand(false_computation_arg);
// In called_computations_, the index of true_computation must be 0 and that
// of false computation must be 1, as defined by kTrueComputationIndex and
// kFalseComputationIndex.
instruction->called_computations_.push_back(true_computation);
instruction->called_computations_.push_back(false_computation);
return instruction;
}
// N-way form: creates a kConditional instruction whose operand 0 is
// `branch_index`, followed by one argument operand per branch.
// `branch_computations[i]` is stored at index i of called_computations_ and
// `branch_computation_args[i]` becomes operand i + 1. Both spans must have
// the same length.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConditional(
    const Shape& shape, HloInstruction* branch_index,
    absl::Span<HloComputation* const> branch_computations,
    absl::Span<HloInstruction* const> branch_computation_args) {
  auto instruction =
      absl::WrapUnique(new HloInstruction(HloOpcode::kConditional, shape));
  instruction->AppendOperand(branch_index);
  CHECK_EQ(branch_computations.size(), branch_computation_args.size());
  // Index with size_t so the comparison against Span::size() is not a
  // signed/unsigned mix.
  for (size_t i = 0; i < branch_computations.size(); ++i) {
    instruction->called_computations_.push_back(branch_computations[i]);
    instruction->AppendOperand(branch_computation_args[i]);
  }
  return instruction;
}
// Builds an HloSliceInstruction on `operand` from the per-dimension
// `start_indices`, `limit_indices` and `strides`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSlice(
    const Shape& shape, HloInstruction* operand,
    absl::Span<const int64_t> start_indices,
    absl::Span<const int64_t> limit_indices,
    absl::Span<const int64_t> strides) {
  std::unique_ptr<HloInstruction> slice =
      std::make_unique<HloSliceInstruction>(shape, operand, start_indices,
                                            limit_indices, strides);
  return slice;
}
// Builds an HloDynamicSliceInstruction on `operand`, with start indices
// supplied as instructions and static `slice_sizes`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDynamicSlice(
    const Shape& shape, HloInstruction* operand,
    absl::Span<HloInstruction* const> start_indices,
    absl::Span<const int64_t> slice_sizes) {
  std::unique_ptr<HloInstruction> dynamic_slice =
      std::make_unique<HloDynamicSliceInstruction>(shape, operand,
                                                   start_indices, slice_sizes);
  return dynamic_slice;
}
// Builds an HloDynamicUpdateSliceInstruction from `operand`, `update` and
// start indices supplied as instructions.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateDynamicUpdateSlice(
    const Shape& shape, HloInstruction* operand, HloInstruction* update,
    absl::Span<HloInstruction* const> start_indices) {
  std::unique_ptr<HloInstruction> dynamic_update_slice =
      std::make_unique<HloDynamicUpdateSliceInstruction>(shape, operand, update,
                                                         start_indices);
  return dynamic_update_slice;
}
// Builds an HloConcatenateInstruction over `operands` along `dimension`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConcatenate(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    int64_t dimension) {
  std::unique_ptr<HloInstruction> concatenate =
      std::make_unique<HloConcatenateInstruction>(shape, operands, dimension);
  return concatenate;
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvert(
const Shape& shape, HloInstruction* operand) {
auto instruction =
absl::WrapUnique(new HloInstruction(HloOpcode::kConvert, shape));
instruction->AppendOperand(operand);
return instruction;
}
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateBitcastConvert(const Shape& shape,
HloInstruction* operand) {
auto instruction =
absl::WrapUnique(new HloInstruction(HloOpcode::kBitcastConvert, shape));
instruction->AppendOperand(operand);
return instruction;
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBitcast(
const Shape& shape, HloInstruction* operand) {
auto instruction =
absl::WrapUnique(new HloInstruction(HloOpcode::kBitcast, shape));
instruction->AppendOperand(operand);
return instruction;
}
// Single-input form: builds an HloReduceInstruction whose argument list is
// {operand, init_value}, reducing over `dimensions_to_reduce` with
// `reduce_computation`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
    const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
    absl::Span<const int64_t> dimensions_to_reduce,
    HloComputation* reduce_computation) {
  // Constructed with `new` so the braced operand list can be passed straight
  // to the constructor.
  std::unique_ptr<HloInstruction> reduce(new HloReduceInstruction(
      shape, {operand, init_value}, dimensions_to_reduce, reduce_computation));
  return reduce;
}
// Variadic form: builds an HloReduceInstruction whose argument list is all of
// `operands` followed by all of `init_values`, reducing over
// `dimensions_to_reduce` with `reduce_computation`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    absl::Span<HloInstruction* const> init_values,
    absl::Span<const int64_t> dimensions_to_reduce,
    HloComputation* reduce_computation) {
  std::vector<HloInstruction*> all_args;
  // Reserve exactly what is inserted below instead of assuming that
  // `init_values` has the same length as `operands`.
  all_args.reserve(operands.size() + init_values.size());
  all_args.insert(all_args.end(), operands.begin(), operands.end());
  all_args.insert(all_args.end(), init_values.begin(), init_values.end());
  return std::make_unique<HloReduceInstruction>(
      shape, all_args, dimensions_to_reduce, reduce_computation);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
const Shape& shape, HloInstruction* tuple_of_instructions,
absl::Span<HloInstruction* const> init_values,
absl::Span<const int64_t> dimensions_to_reduce,
HloComputation* reduce_computation) {
if (!tuple_of_instructions->shape().IsTuple()) {
CHECK_EQ(init_values.size(), 1)
<< "The first input has to be a tuple, or the number of init values "
"has to be one.";
return CreateReduce(shape, tuple_of_instructions, init_values[0],
dimensions_to_reduce, reduce_computation);
}
absl::InlinedVector<HloInstruction*, 4> inputs;
for (int idx = 0; idx < tuple_of_instructions->shape().tuple_shapes_size();
idx++) {
std::unique_ptr<HloInstruction> gte =
HloInstruction::CreateGetTupleElement(tuple_of_instructions, idx);
inputs.push_back(
tuple_of_instructions->parent()->AddInstruction(std::move(gte)));
}
return CreateReduce(shape, inputs, init_values, dimensions_to_reduce,
reduce_computation);
}
// Single-input form: builds an HloReduceWindowInstruction from `operand`,
// `init_value`, `window` and `reduce_computation`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduceWindow(
    const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
    const Window& window, HloComputation* reduce_computation) {
  std::unique_ptr<HloInstruction> reduce_window =
      std::make_unique<HloReduceWindowInstruction>(
          shape, operand, init_value, window, reduce_computation);
  return reduce_window;
}
// Variadic form: builds an HloReduceWindowInstruction from spans of
// `operands` and `init_values`, plus `window` and `reduce_computation`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduceWindow(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    absl::Span<HloInstruction* const> init_values, const Window& window,
    HloComputation* reduce_computation) {
  std::unique_ptr<HloInstruction> reduce_window =
      std::make_unique<HloReduceWindowInstruction>(
          shape, operands, init_values, window, reduce_computation);
  return reduce_window;
}
// Builds an HloBatchNormTrainingInstruction from `operand`, `scale`,
// `offset`, `epsilon` and `feature_index`.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateBatchNormTraining(const Shape& shape,
                                        HloInstruction* operand,
                                        HloInstruction* scale,
                                        HloInstruction* offset, float epsilon,
                                        int64_t feature_index) {
  std::unique_ptr<HloInstruction> batch_norm =
      std::make_unique<HloBatchNormTrainingInstruction>(
          shape, operand, scale, offset, epsilon, feature_index);
  return batch_norm;
}
// Builds an HloBatchNormInferenceInstruction from `operand`, `scale`,
// `offset`, `mean`, `variance`, `epsilon` and `feature_index`.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateBatchNormInference(
    const Shape& shape, HloInstruction* operand, HloInstruction* scale,
    HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
    float epsilon, int64_t feature_index) {
  std::unique_ptr<HloInstruction> batch_norm =
      std::make_unique<HloBatchNormInferenceInstruction>(
          shape, operand, scale, offset, mean, variance, epsilon,
          feature_index);
  return batch_norm;
}
// Builds an HloBatchNormGradInstruction from `operand`, `scale`, `mean`,
// `variance`, `grad_output`, `epsilon` and `feature_index`.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateBatchNormGrad(const Shape& shape, HloInstruction* operand,
                                    HloInstruction* scale, HloInstruction* mean,
                                    HloInstruction* variance,
                                    HloInstruction* grad_output, float epsilon,
                                    int64_t feature_index) {
  std::unique_ptr<HloInstruction> batch_norm_grad =
      std::make_unique<HloBatchNormGradInstruction>(
          shape, operand, scale, mean, variance, grad_output, epsilon,
          feature_index);
  return batch_norm_grad;
}
// Builds an HloSelectAndScatterInstruction from `operand`, the `select`
// computation, `window`, `source`, `init_value` and the `scatter`
// computation.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateSelectAndScatter(
    const Shape& shape, HloInstruction* operand, HloComputation* select,
    const Window& window, HloInstruction* source, HloInstruction* init_value,
    HloComputation* scatter) {
  std::unique_ptr<HloInstruction> select_and_scatter =
      std::make_unique<HloSelectAndScatterInstruction>(
          shape, operand, select, window, source, init_value, scatter);
  return select_and_scatter;
}
// Builds an HloBroadcastInstruction of `operand` to `shape` using
// `broadcast_dimensions`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBroadcast(
    const Shape& shape, HloInstruction* operand,
    absl::Span<const int64_t> broadcast_dimensions) {
  std::unique_ptr<HloInstruction> broadcast =
      std::make_unique<HloBroadcastInstruction>(shape, operand,
                                                broadcast_dimensions);
  return broadcast;
}
// Builds an HloGetDimensionSizeInstruction on `operand` for `dimension`.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateGetDimensionSize(const Shape& shape,
                                       HloInstruction* operand,
                                       int64_t dimension) {
  std::unique_ptr<HloInstruction> get_dimension_size =
      std::make_unique<HloGetDimensionSizeInstruction>(shape, operand,
                                                       dimension);
  return get_dimension_size;
}
// Builds an HloSetDimensionSizeInstruction from `operand`, the size operand
// `val`, and `dimension`.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateSetDimensionSize(const Shape& shape,
                                       HloInstruction* operand,
                                       HloInstruction* val, int64_t dimension) {
  std::unique_ptr<HloInstruction> set_dimension_size =
      std::make_unique<HloSetDimensionSizeInstruction>(shape, operand, val,
                                                       dimension);
  return set_dimension_size;
}
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateBroadcastSequence(
const Shape& output_shape, HloInstruction* operand,
const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
adder) {
CHECK(ShapeUtil::IsScalar(operand->shape()) ||
operand->shape().rank() == output_shape.rank());
Shape broadcast_shape = ShapeUtil::ChangeElementType(
output_shape, operand->shape().element_type());
// Do explicit broadcast for scalar.
if (ShapeUtil::IsScalar(operand->shape())) {
auto broadcast =
HloInstruction::CreateBroadcast(broadcast_shape, operand, {});
broadcast->set_metadata(operand->metadata());
if (operand->has_sharding()) {
broadcast->set_sharding(operand->sharding());
}
broadcast->set_frontend_attributes(operand->frontend_attributes());
return broadcast;
}
// Do explicit broadcast for degenerate broadcast.
std::vector<int64_t> broadcast_dimensions;
std::vector<int64_t> reshaped_dimensions;
for (int i = 0; i < operand->shape().rank(); i++) {
if (operand->shape().dimensions(i) == output_shape.dimensions(i)) {
broadcast_dimensions.push_back(i);
reshaped_dimensions.push_back(operand->shape().dimensions(i));
} else {
CHECK_EQ(operand->shape().dimensions(i), 1)
<< "An explicit broadcast sequence requires the broadcasted "
"dimensions to be trivial; operand: "
<< operand->ToString() << "; output_shape: " << output_shape;
}
}
// Eliminate the size one dimensions.
HloInstruction* reshaped_operand = adder(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(operand->shape().element_type(),
reshaped_dimensions),
operand));
reshaped_operand->set_metadata(operand->metadata());
if (operand->has_sharding()) {
reshaped_operand->set_sharding(operand->sharding());
}
reshaped_operand->set_frontend_attributes(operand->frontend_attributes());
// Broadcast 'reshape' up to the larger size.
auto broadcast = HloInstruction::CreateBroadcast(
broadcast_shape, reshaped_operand, broadcast_dimensions);
broadcast->set_metadata(operand->metadata());
if (operand->has_sharding()) {
broadcast->set_sharding(operand->sharding());
}
broadcast->set_frontend_attributes(operand->frontend_attributes());
return broadcast;
}
// Factory for a pad instruction: constructs an HloPadInstruction from
// `shape`, `operand`, `padding_value` and `padding_config`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreatePad(
    const Shape& shape, HloInstruction* operand, HloInstruction* padding_value,
    const PaddingConfig& padding_config) {
  std::unique_ptr<HloInstruction> instruction =
      std::make_unique<HloPadInstruction>(shape, operand, padding_value,
                                          padding_config);
  return instruction;
}
// Factory for a reshape instruction. CHECK-fails unless `shape` holds exactly
// as many elements as the operand's shape; otherwise constructs an
// HloReshapeInstruction, forwarding `inferred_dimension`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReshape(
    const Shape& shape, HloInstruction* operand, int64_t inferred_dimension) {
  // A reshape may rearrange dimensions but must keep the element count.
  CHECK_EQ(ShapeUtil::ElementsIn(shape),
           ShapeUtil::ElementsIn(operand->shape()))
      << "shape: " << ShapeUtil::HumanString(shape)
      << " operand: " << ShapeUtil::HumanString(operand->shape());
  std::unique_ptr<HloInstruction> instruction =
      std::make_unique<HloReshapeInstruction>(shape, operand,
                                              inferred_dimension);
  return instruction;
}
// Factory for a dynamic-reshape instruction.
//
// CHECK-fails unless (a) `shape` has the same element count as the shape of
// `data_operand`, and (b) `dim_sizes` supplies one entry per dimension of
// `shape`. Otherwise constructs an HloDynamicReshapeInstruction.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateDynamicReshape(
    const Shape& shape, HloInstruction* data_operand,
    absl::Span<HloInstruction* const> dim_sizes) {
  // `data_operand[0]` is simply `*data_operand`.
  CHECK_EQ(ShapeUtil::ElementsIn(shape),
           ShapeUtil::ElementsIn(data_operand[0].shape()))
      << "shape: " << ShapeUtil::HumanString(shape)
      << " operand: " << ShapeUtil::HumanString(data_operand[0].shape());
  CHECK_EQ(shape.rank(), dim_sizes.size());
  std::unique_ptr<HloInstruction> instruction =
      std::make_unique<HloDynamicReshapeInstruction>(shape, data_operand,
                                                     dim_sizes);
  return instruction;
}
// Factory for a transpose instruction: constructs an HloTransposeInstruction
// from `shape`, `operand` and `dimensions`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTranspose(
    const Shape& shape, HloInstruction* operand,
    absl::Span<const int64_t> dimensions) {
  std::unique_ptr<HloInstruction> instruction =
      std::make_unique<HloTransposeInstruction>(shape, operand, dimensions);
  return instruction;
}
// Factory for a sort instruction: constructs an HloSortInstruction from
// `shape`, `dimension`, `operands`, the `compare` computation and the
// `is_stable` flag.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSort(
    const Shape& shape, int64_t dimension,
    absl::Span<HloInstruction* const> operands, HloComputation* compare,
    bool is_stable) {
  std::unique_ptr<HloInstruction> instruction =
      std::make_unique<HloSortInstruction>(shape, dimension, operands, compare,
                                           is_stable);
  return instruction;
}
// Factory for a fusion instruction described by its fused root: constructs an
// HloFusionInstruction from `shape`, `fusion_kind` and `fused_root`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
    const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) {
  std::unique_ptr<HloInstruction> instruction =
      std::make_unique<HloFusionInstruction>(shape, fusion_kind, fused_root);
  return instruction;
}
// Factory for a fusion instruction described by explicit operands and a
// fusion computation: constructs an HloFusionInstruction from `shape`,
// `fusion_kind`, `operands` and `fusion_computation`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
    const Shape& shape, FusionKind fusion_kind,
    absl::Span<HloInstruction* const> operands,
    HloComputation* fusion_computation) {
  std::unique_ptr<HloInstruction> instruction =
      std::make_unique<HloFusionInstruction>(shape, fusion_kind, operands,
                                             fusion_computation);
  return instruction;
}
// Applies a single, non-tuple `sharding` to this instruction. CHECK-fails if
// `sharding` is itself a tuple sharding. For a tuple-shaped instruction the
// sharding is first expanded over the instruction's shape tree and wrapped as
// a tuple sharding; otherwise it is set as-is.
void HloInstruction::set_single_sharding(const HloSharding& sharding) {
  CHECK(!sharding.IsTuple()) << sharding;
  if (!shape().IsTuple()) {
    set_sharding(sharding);
    return;
  }
  set_sharding(HloSharding::Tuple(sharding.GetAsShapeTree(shape())));
}
// Propagates this instruction's sharding, metadata and frontend attributes to
// `derived_instruction`.
//
// The sharding is copied only when this instruction has one and the two
// shapes are of a compatible kind; copying a sharding between instructions
// with different tuple structure can yield an invalid sharding. In every
// other case the derived instruction's sharding is cleared.
void HloInstruction::SetupDerivedInstruction(
    HloInstruction* derived_instruction) const {
  const bool can_copy_sharding =
      sharding_ != nullptr &&
      ShapeUtil::CompatibleKind(shape_, derived_instruction->shape());
  if (!can_copy_sharding) {
    derived_instruction->clear_sharding();
  } else {
    derived_instruction->set_sharding(*sharding_);
  }
  derived_instruction->set_metadata(metadata_);
  derived_instruction->set_frontend_attributes(frontend_attributes_);
}
// Returns true iff this instruction has a parent computation and is that
// computation's root instruction.
bool HloInstruction::IsRoot() const {
  if (parent_ == nullptr) {
    return false;
  }
  return parent_->root_instruction() == this;
}
// Reports whether this instruction, judged by itself, has a side effect.
// Called computations are not inspected (see HasSideEffect for that).
bool HloInstruction::HasSideEffectNoRecurse() const {
  switch (opcode_) {
    // Opcodes whose answer depends on the instruction's own attributes.
    case HloOpcode::kAllReduce: {
      if (channel_id().has_value()) {
        return true;
      }
      return Cast<HloAllReduceInstruction>(this)->constrain_layout();
    }
    case HloOpcode::kAllToAll: {
      return Cast<HloAllToAllInstruction>(this)->constrain_layout();
    }
    case HloOpcode::kCustomCall: {
      return Cast<HloCustomCallInstruction>(this)
          ->custom_call_has_side_effect();
    }

    // Opcodes that are always treated as side-effecting.
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kRng:
    case HloOpcode::kRngGetAndUpdateState:
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kAllReduceStart:
    case HloOpcode::kAllReduceDone:
    case HloOpcode::kAllGatherStart:
    case HloOpcode::kAllGatherDone:
    case HloOpcode::kCollectivePermuteStart:
    case HloOpcode::kCollectivePermuteDone:
      return true;

    // Everything else is side-effect free at this level.
    default:
      return false;
  }
}
bool HloInstruction::HasSideEffect() const {
if (HasSideEffectNoRecurse()) {
return true;
}
// Check if any of the called computations has a side effect.
for (const auto& computation : called_computations()) {
if (computation->HasSideEffect()) {
return true;
}
}
return false;
}
// Factory for a call instruction described by the root of the called
// computation: constructs an HloCallInstruction from `shape` and
// `called_computation_root`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCall(
    const Shape& shape, HloInstruction* called_computation_root) {
  std::unique_ptr<HloInstruction> instruction =
      std::make_unique<HloCallInstruction>(shape, called_computation_root);
  return instruction;
}
// Factory for a call instruction with explicit operands: constructs an
// HloCallInstruction from `shape`, `operands` and `computation`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCall(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    HloComputation* computation) {
  std::unique_ptr<HloInstruction> instruction =
      std::make_unique<HloCallInstruction>(shape, operands, computation);
  return instruction;
}
// Factory for a custom call that has no called computation and no operand
// layout constraints. `opaque` is taken by value and moved into the new
// HloCustomCallInstruction.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    absl::string_view custom_call_target, std::string opaque,
    CustomCallApiVersion api_version) {
  std::unique_ptr<HloInstruction> instruction =
      std::make_unique<HloCustomCallInstruction>(
          shape, operands, custom_call_target, std::move(opaque), api_version);
  return instruction;
}
// Factory for a custom call with a single `to_apply` computation. `opaque`
// is taken by value and moved into the new HloCustomCallInstruction.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    HloComputation* to_apply, absl::string_view custom_call_target,
    std::string opaque, CustomCallApiVersion api_version) {
  std::unique_ptr<HloInstruction> instruction =
      std::make_unique<HloCustomCallInstruction>(
          shape, operands, to_apply, custom_call_target, std::move(opaque),
          api_version);
  return instruction;
}
// Factory for a custom call with a list of called computations. `opaque` is
// taken by value and moved into the new HloCustomCallInstruction.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    absl::Span<HloComputation* const> called_computations,
    absl::string_view custom_call_target, std::string opaque,
    CustomCallApiVersion api_version) {
  std::unique_ptr<HloInstruction> instruction =
      std::make_unique<HloCustomCallInstruction>(
          shape, operands, called_computations, custom_call_target,
          std::move(opaque), api_version);
  return instruction;
}
// Factory for a custom call that carries `operand_shapes_with_layout`.
// Note the constructor's argument order differs from this function's:
// `opaque` (moved) comes before the operand shapes.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    absl::string_view custom_call_target,
    absl::Span<const Shape> operand_shapes_with_layout, std::string opaque,
    CustomCallApiVersion api_version) {
  std::unique_ptr<HloInstruction> instruction =
      std::make_unique<HloCustomCallInstruction>(
          shape, operands, custom_call_target, std::move(opaque),
          operand_shapes_with_layout, api_version);
  return instruction;
}
// Factory for a tuple instruction. The tuple shape is derived from the shapes
// of `elements`, in order, and the instruction is created as a variadic
// kTuple over those elements.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTuple(
    absl::Span<HloInstruction* const> elements) {
  // Collect pointers to the element shapes; the shapes themselves are not
  // copied here.
  std::vector<const Shape*> element_shapes;
  element_shapes.reserve(elements.size());
  absl::c_transform(
      elements, std::back_inserter(element_shapes),
      [](const HloInstruction* element) { return &element->shape(); });
  Shape tuple_shape = ShapeUtil::MakeTupleShapeWithPtrs(element_shapes);
  return CreateVariadic(tuple_shape, HloOpcode::kTuple, elements);
}
// Factory for a gather instruction: constructs an HloGatherInstruction from
// `shape`, `operand`, `start_indices`, the gather dimension numbers,
// `slice_sizes` and the `indices_are_sorted` flag.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateGather(
    const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
    const GatherDimensionNumbers& gather_dim_numbers,
    absl::Span<const int64_t> slice_sizes, bool indices_are_sorted) {
  std::unique_ptr<HloInstruction> instruction =
      std::make_unique<HloGatherInstruction>(shape, operand, start_indices,
                                             gather_dim_numbers, slice_sizes,
                                             indices_are_sorted);
  return instruction;
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateScatter(
const Shape& shape, HloInstruction* operand,
HloInstruction* scatter_indices, HloInstruction* updates,
HloComputation* update_computation,
const ScatterDimensionNumbers& scatter_dim_numbers, bool indices_are_sorted,
bool unique_indices) {
return absl::WrapUnique(new HloScatterInstruction(
shape, {operand, scatter_indices, updates}, update_computation,
scatter_dim_numbers, indices_are_sorted, unique_indices));
}
// Factory for a scatter instruction with several operands and several
// updates tensors. The HloScatterInstruction receives one flat argument list
// laid out as: all `operands`, then `scatter_indices`, then all `updates`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateScatter(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    HloInstruction* scatter_indices, absl::Span<HloInstruction* const> updates,
    HloComputation* update_computation,
    const ScatterDimensionNumbers& scatter_dim_numbers, bool indices_are_sorted,
    bool unique_indices) {
  absl::InlinedVector<HloInstruction*, 3> args;
  args.reserve(operands.size() + updates.size() + 1);
  args.insert(args.end(), operands.begin(), operands.end());
  args.push_back(scatter_indices);
  args.insert(args.end(), updates.begin(), updates.end());
  std::unique_ptr<HloInstruction> instruction =
      std::make_unique<HloScatterInstruction>(
          shape, args, update_computation, scatter_dim_numbers,
          indices_are_sorted, unique_indices);
  return instruction;
}
// Factory for a domain instruction. Ownership of both metadata objects is
// transferred into the new HloDomainInstruction.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDomain(
    const Shape& shape, HloInstruction* operand,
    std::unique_ptr<DomainMetadata> operand_side_metadata,
    std::unique_ptr<DomainMetadata> user_side_metadata) {
  std::unique_ptr<HloInstruction> instruction =
      std::make_unique<HloDomainInstruction>(
          shape, operand, std::move(operand_side_metadata),
          std::move(user_side_metadata));
  return instruction;
}
// Creates a copy of this instruction that has the given `shape` and uses
// `new_operands` in place of the current operands.
//
// The clone is produced by dispatching on the opcode:
//  * opcodes implemented by subclasses go through CloneWithNewOperandsImpl;
//  * the remaining opcodes call the matching Create* factory, after checking
//    the operand count where the opcode has a fixed arity.
//
// After construction the clone receives this instruction's sharding (when the
// shapes are compatible), metadata and frontend attributes (via
// SetupDerivedInstruction), plus its parent, outer dimension partitions,
// backend config and name.
//
// If `context` is non-null, the this->clone mapping is recorded in it, and
// every called computation of the clone that does not belong to
// `context->module()` is replaced by a deep clone made in that module.
std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
VLOG(3) << "CloneWithNewOperands:\n " << ToString();
VLOG(3) << " new operands:";
for (const HloInstruction* new_operand : new_operands) {
VLOG(3) << " %" << new_operand->name();
}
std::unique_ptr<HloInstruction> clone;
// Explicitly call the factory for the instruction type. This is more robust
// in the face of code changes than copying fields explicitly. This also
// properly sets the user fields of the operands.
switch (opcode_) {
// Ops migrated to subclasses.
// TODO(b/80131774): Remove this switch when migration is complete.
case HloOpcode::kBatchNormTraining:
case HloOpcode::kBatchNormInference:
case HloOpcode::kBatchNormGrad:
case HloOpcode::kFft:
case HloOpcode::kCompare:
case HloOpcode::kAsyncStart:
case HloOpcode::kAsyncUpdate:
case HloOpcode::kAsyncDone:
case HloOpcode::kCopyStart:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
case HloOpcode::kReverse:
case HloOpcode::kConcatenate:
case HloOpcode::kReduce:
case HloOpcode::kTranspose:
case HloOpcode::kBroadcast:
case HloOpcode::kReshape:
case HloOpcode::kDynamicReshape:
case HloOpcode::kMap:
case HloOpcode::kSlice:
case HloOpcode::kConstant:
case HloOpcode::kFusion:
case HloOpcode::kRng:
case HloOpcode::kRngBitGenerator:
case HloOpcode::kRngGetAndUpdateState:
case HloOpcode::kParameter:
case HloOpcode::kGetTupleElement:
case HloOpcode::kReducePrecision:
case HloOpcode::kAllGather:
case HloOpcode::kAllGatherStart:
case HloOpcode::kAllReduce:
case HloOpcode::kReduceScatter:
case HloOpcode::kAllReduceStart:
case HloOpcode::kAllToAll:
case HloOpcode::kCollectivePermute:
case HloOpcode::kCollectivePermuteStart:
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kConvolution:
case HloOpcode::kCustomCall:
case HloOpcode::kReduceWindow:
case HloOpcode::kSelectAndScatter:
case HloOpcode::kPad:
case HloOpcode::kDynamicSlice:
case HloOpcode::kSort:
case HloOpcode::kGather:
case HloOpcode::kScatter:
case HloOpcode::kIota:
case HloOpcode::kDot:
case HloOpcode::kDomain:
case HloOpcode::kGetDimensionSize:
case HloOpcode::kSetDimensionSize:
case HloOpcode::kTriangularSolve:
case HloOpcode::kCholesky:
clone = CloneWithNewOperandsImpl(shape, new_operands, context);
break;
// Unary ops.
case HloOpcode::kAbs:
case HloOpcode::kAllGatherDone:
case HloOpcode::kAllReduceDone:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kRoundNearestEven:
case HloOpcode::kBitcast:
case HloOpcode::kCeil:
case HloOpcode::kClz:
case HloOpcode::kCollectivePermuteDone:
case HloOpcode::kCopy:
case HloOpcode::kOptimizationBarrier:
case HloOpcode::kCopyDone:
case HloOpcode::kCos:
case HloOpcode::kExp:
case HloOpcode::kExpm1:
case HloOpcode::kImag:
case HloOpcode::kIsFinite:
case HloOpcode::kFloor:
case HloOpcode::kLog:
case HloOpcode::kLog1p:
case HloOpcode::kNot:
case HloOpcode::kNegate:
case HloOpcode::kPopulationCount:
case HloOpcode::kReal:
case HloOpcode::kRsqrt:
case HloOpcode::kLogistic:
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kSqrt:
case HloOpcode::kCbrt:
case HloOpcode::kTanh:
CHECK_EQ(new_operands.size(), 1);
clone = CreateUnary(shape, opcode_, new_operands[0]);
break;
// Binary ops.
case HloOpcode::kAdd:
case HloOpcode::kAtan2:
case HloOpcode::kComplex:
case HloOpcode::kDivide:
case HloOpcode::kMultiply:
case HloOpcode::kSubtract:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kPower:
case HloOpcode::kRemainder:
case HloOpcode::kAnd:
case HloOpcode::kOr:
case HloOpcode::kXor:
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical:
CHECK_EQ(new_operands.size(), 2);
clone = CreateBinary(shape, opcode_, new_operands[0], new_operands[1]);
break;
// Ternary ops.
case HloOpcode::kClamp:
case HloOpcode::kSelect:
CHECK_EQ(new_operands.size(), 3);
clone = CreateTernary(shape, opcode_, new_operands[0], new_operands[1],
new_operands[2]);
break;
// Other supported ops.
case HloOpcode::kCall:
clone = CreateCall(shape, new_operands, to_apply());
break;
case HloOpcode::kConvert:
CHECK_EQ(new_operands.size(), 1);
clone = CreateConvert(shape, new_operands[0]);
break;
case HloOpcode::kBitcastConvert:
CHECK_EQ(new_operands.size(), 1);
clone = CreateBitcastConvert(shape, new_operands[0]);
break;
case HloOpcode::kDynamicUpdateSlice:
// Operands 0 and 1 are passed individually; everything from index 2 on is
// passed as a span. No operand-count check is performed for this opcode.
clone = CreateDynamicUpdateSlice(shape, new_operands[0], new_operands[1],
new_operands.subspan(2));
break;
case HloOpcode::kTuple:
// CreateTuple derives the tuple shape from the operands, so the requested
// `shape` is written over it afterwards.
clone = CreateTuple(new_operands);
*clone->mutable_shape() = shape;
break;
case HloOpcode::kWhile:
CHECK_EQ(new_operands.size(), 1);
clone =
CreateWhile(shape, while_condition(), while_body(), new_operands[0]);
break;
case HloOpcode::kConditional:
// Operand 0 is passed on its own; the remaining operands (one per branch)
// are passed together with the branch computations.
CHECK_EQ(new_operands.size(), branch_count() + 1);
clone = CreateConditional(shape, new_operands[0],
absl::MakeSpan(branch_computations()),
new_operands.subspan(1));
break;
case HloOpcode::kAfterAll:
// An after-all with no operands is cloned as a token.
if (new_operands.empty()) {
clone = CreateToken();
} else {
clone = CreateAfterAll(new_operands);
}
break;
case HloOpcode::kAddDependency:
CHECK_EQ(new_operands.size(), 2);
clone = CreateAddDependency(new_operands[0], new_operands[1]);
break;
case HloOpcode::kReplicaId:
CHECK_EQ(new_operands.size(), 0);
clone = CreateReplicaId(shape);
break;
case HloOpcode::kPartitionId:
CHECK_EQ(new_operands.size(), 0);
clone = CreatePartitionId(shape);
break;
}
// SetupDerivedInstruction copies the sharding (when the shapes are of a
// compatible kind), the metadata and the frontend attributes to the clone.
SetupDerivedInstruction(clone.get());
clone->set_parent(parent_);
clone->set_outer_dimension_partitions(outer_dimension_partitions_);
clone->backend_config_ = backend_config_.Clone();
// The new instruction's name will be uniquified when it's added to a
// computation.
clone->SetAndSanitizeName(name());
if (context != nullptr) {
context->MapInstruction(this, clone.get());
// Computations that belong to a module other than the context's module are
// deep-cloned; computations already in that module are reused.
clone->ReplaceCalledComputations([&](HloComputation* callee) {
return callee->parent() != context->module()
? context->module()->DeepCloneComputation(callee, context)
: callee;
});
}
return clone;
}
void HloInstruction::DetachFromOperandsAndUsers() {
if (cleaned_up_) {
return;
}
cleaned_up_ = true;
// Detach from operands. An instruction may be repeated as an operand. To
// avoid calling RemoveUser twice on the same operand, check before remove.
for (int64_t operand_num = 0; operand_num < operand_count(); ++operand_num) {
HloInstruction* operand = operands_[operand_num];
if (operand == nullptr) {
continue;
}
if (operand->user_map_.find(this) != operand->user_map_.end()) {
operand->RemoveUser(this);
}
operands_[operand_num] = nullptr;
}
// Update users. Set `nullptr` to the corresponding operand slot for users.
for (auto& user : this->users()) {
for (int i = 0; i < user->operand_count(); ++i) {
if (user->operands_[i] == this) {
user->operands_[i] = nullptr;
}
}
}
}
// Clones this instruction with its current operands but the given `shape`,
// and names the clone based on `suffix`:
//  * empty suffix                  -> same name as the original;
//  * name has no ".suffix"         -> name + ".suffix";
//  * name ends in ".suffix"        -> name + "2";
//  * name ends in ".suffix<N>"     -> ".suffix<N+1>" replaces the tail;
//  * ".suffix" followed by non-numeric text -> name + ".suffix".
// The lookup uses the last occurrence of ".suffix" in the name. This avoids
// names such as foo.suffix.suffix.suffix when cloning repeatedly.
std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewShape(
    const Shape& shape, const std::string& suffix,
    HloCloneContext* context) const {
  std::unique_ptr<HloInstruction> clone =
      CloneWithNewOperands(shape, operands_, context);

  if (suffix.empty()) {
    clone->name_ = name();
    return clone;
  }

  const std::string dot_suffix = "." + suffix;
  const size_t suffix_pos = name().rfind(dot_suffix);
  if (suffix_pos == std::string::npos) {
    // First clone with this suffix.
    clone->name_ = name() + dot_suffix;
    return clone;
  }

  // The name already contains ".suffix"; look at what follows it.
  const std::string trailing = name().substr(suffix_pos + dot_suffix.size());
  int64_t numeric_suffix;
  if (trailing.empty()) {
    // "foo.suffix" -> "foo.suffix2".
    clone->name_ = name() + "2";
  } else if (absl::SimpleAtoi(trailing, &numeric_suffix)) {
    // "foo.suffixN" -> "foo.suffix(N+1)".
    clone->name_ =
        StrCat(name().substr(0, suffix_pos), dot_suffix, numeric_suffix + 1);
  } else {
    // Non-numeric tail: treat as if the suffix were not present.
    clone->name_ = name() + dot_suffix;
  }
  return clone;
}
// Clones this instruction with its current shape and operands; naming of the
// clone follows CloneWithNewShape's handling of `suffix`.
std::unique_ptr<HloInstruction> HloInstruction::Clone(
    const std::string& suffix, HloCloneContext* context) const {
  return CloneWithNewShape(shape_, suffix, context);
}
std::pair<const HloInstruction*, ShapeIndex>
HloInstruction::LatestNonGteAncestorAndIndex() const {
const HloInstruction* hlo = this;
ShapeIndex index;
while (hlo->opcode() == HloOpcode::kGetTupleElement) {
index.push_back(hlo->tuple_index());
hlo = hlo->operand(0);
}
// We built up index in the reverse order from what we want.
std::reverse(index.begin(), index.end());
return {hlo, index};
}
// Follows operand 0 through any chain of get-tuple-element instructions and
// returns the first instruction that is not a get-tuple-element (possibly
// this instruction itself).
const HloInstruction* HloInstruction::LatestNonGteAncestor() const {
  const HloInstruction* ancestor = this;
  for (; ancestor->opcode() == HloOpcode::kGetTupleElement;
       ancestor = ancestor->operand(0)) {
  }
  return ancestor;
}
// Returns the i-th operand. The index goes through the container's `at()`
// accessor rather than `operator[]`.
const HloInstruction* HloInstruction::operand(int64_t i) const {
  const HloInstruction* const result = operands_.at(i);
  return result;
}
// Returns a mutable pointer to the i-th operand.
//
// The slot is fetched through `at()` first, so an out-of-range index is
// reported by the container instead of being read with unchecked
// `operator[]`. A null slot (e.g. one cleared by DetachFromOperandsAndUsers)
// then CHECK-fails.
HloInstruction* HloInstruction::mutable_operand(int64_t i) {
  HloInstruction* const result = operands_.at(i);
  CHECK(result != nullptr) << "operand " << i << " is null";
  return result;
}
// Returns the index of the first operand slot that holds `target`.
// Terminates the process with LOG(FATAL) if `target` is not an operand.
int64_t HloInstruction::operand_index(const HloInstruction* target) const {
  int64_t position = 0;
  while (position < operand_count()) {
    if (operand(position) == target) {
      return position;
    }
    ++position;
  }
  LOG(FATAL) << "target was not an operand: " << target->ToString();
}
// Returns the operands with duplicates removed, keeping the order in which
// each distinct operand first appears.
HloInstruction::InstructionVector HloInstruction::unique_operands() const {
  InstructionVector result;
  absl::flat_hash_set<const HloInstruction*> visited;
  for (HloInstruction* candidate : operands()) {
    const bool first_occurrence = visited.insert(candidate).second;
    if (!first_occurrence) {
      continue;
    }
    result.push_back(candidate);
  }
  return result;
}
// Adds a control edge from this instruction to `instruction`, i.e. records
// `instruction` as a control successor of this and this as a control
// predecessor of `instruction`.
//
// Returns an error if the two instructions have different parents, or if the
// two edge lists disagree (this is already listed as a predecessor although
// `instruction` is not yet listed as a successor). Adding an edge that
// already exists is a no-op.
//
// All validation happens before either list is modified, so a failed check
// never leaves a one-sided edge behind.
Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) {
  TF_RET_CHECK(instruction->parent() == parent());
  if (absl::c_linear_search(control_successors_, instruction)) {
    // Edge already present.
    return OkStatus();
  }
  TF_RET_CHECK(
      !absl::c_linear_search(instruction->control_predecessors_, this));
  control_successors_.push_back(instruction);
  instruction->control_predecessors_.push_back(this);
  return OkStatus();
}
// Removes the control edge from this instruction to `instruction`: erases
// `instruction` from this instruction's control successors, then erases this
// from `instruction`'s control predecessors. Returns an error if the parents
// differ or if either erase reports a failure.
Status HloInstruction::RemoveControlDependencyTo(HloInstruction* instruction) {
  TF_RET_CHECK(instruction->parent() == parent());
  TF_RETURN_IF_ERROR(EraseElementFromVector(&control_successors_, instruction));
  return EraseElementFromVector(&instruction->control_predecessors_, this);
}
// Removes every control edge touching this instruction: first unregisters
// this from the predecessor list of each control successor and from the
// successor list of each control predecessor, then empties both of this
// instruction's own lists. Stops and returns the error if any erase fails.
Status HloInstruction::DropAllControlDeps() {
  for (HloInstruction* successor : control_successors_) {
    TF_RETURN_IF_ERROR(
        EraseElementFromVector(&successor->control_predecessors_, this));
  }
  for (HloInstruction* predecessor : control_predecessors_) {
    TF_RETURN_IF_ERROR(
        EraseElementFromVector(&predecessor->control_successors_, this));
  }
  control_successors_.clear();
  control_predecessors_.clear();
  return OkStatus();
}
// Gives this instruction the same control edges as `inst`: every control
// predecessor of `inst` becomes a predecessor of this, and every control
// successor of `inst` becomes a successor of this. Stops at the first error.
Status HloInstruction::CopyAllControlDepsFrom(const HloInstruction* inst) {
  for (HloInstruction* predecessor : inst->control_predecessors()) {
    TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(this));
  }
  for (HloInstruction* successor : inst->control_successors()) {
    TF_RETURN_IF_ERROR(AddControlDependencyTo(successor));
  }
  return OkStatus();
}
// Core of the "identical instruction" comparison.
//
// Two instructions are identical when they have the same opcode, matching
// shapes (exactly equal if `layout_sensitive`, otherwise merely compatible),
// the same number of operands with pairwise `eq_operands`-equal operands,
// equal backend configs, and pass the opcode-specific slow-path comparison.
//
// `ignore_commutative_operand_order` lets the two operands of a binary
// commutative opcode match in either order. `ignore_channel_id_values`
// routes channel instructions to the comparison that disregards channel ids.
bool HloInstruction::IdenticalInternal(
    const HloInstruction& other,
    const std::function<bool(const HloInstruction*, const HloInstruction*)>&
        eq_operands,
    const std::function<bool(const HloComputation*, const HloComputation*)>&
        eq_computations,
    bool layout_sensitive, bool ignore_channel_id_values,
    bool ignore_commutative_operand_order) const {
  // Trivially identical to itself.
  if (this == &other) {
    return true;
  }

  // Cheap structural checks: opcode, shape, operand count.
  if (opcode() != other.opcode()) {
    return false;
  }
  const bool shapes_match =
      layout_sensitive ? ShapeUtil::Equal(shape(), other.shape())
                       : ShapeUtil::Compatible(shape(), other.shape());
  if (!shapes_match) {
    return false;
  }
  if (operands().size() != other.operands().size()) {
    return false;
  }

  // Operand comparison. An explicit loop is used instead of ContainerEquals
  // because copying std::function objects around can be too expensive.
  const bool allow_swapped_operands =
      ignore_commutative_operand_order &&
      HloOpcodeIsBinaryCommutative(opcode());
  if (allow_swapped_operands) {
    CHECK_EQ(operand_count(), 2);
    const bool same_order = eq_operands(operand(0), other.operand(0)) &&
                            eq_operands(operand(1), other.operand(1));
    if (!same_order) {
      const bool swapped_order = eq_operands(operand(0), other.operand(1)) &&
                                 eq_operands(operand(1), other.operand(0));
      if (!swapped_order) {
        return false;
      }
    }
  } else {
    for (size_t i = 0; i < operands().size(); ++i) {
      if (!eq_operands(operand(i), other.operand(i))) {
        return false;
      }
    }
  }

  if (backend_config_ != other.backend_config_) {
    return false;
  }

  // Channel instructions have a dedicated comparison that skips channel ids.
  if (ignore_channel_id_values) {
    const auto* channel_inst = DynCast<HloChannelInstruction>(this);
    if (channel_inst != nullptr) {
      return channel_inst->IdenticalSlowPathIgnoringChannelIdValues(
          other, eq_computations);
    }
  }
  return IdenticalSlowPath(other, eq_computations);
}
// Appends `operand` to this instruction's operand list and registers this
// instruction as one of its users. In debug builds, DCHECKs that an operand
// with a parent computation has not been marked dead in that computation.
void HloInstruction::AppendOperand(HloInstruction* operand) {
  const bool has_parent = operand->parent() != nullptr;
  if (has_parent) {
    DCHECK(!operand->parent()->IsMarkedAsDead(operand))
        << "Operand " << operand->name() << " is already marked dead";
  }
  operands_.push_back(operand);
  operand->AddUser(this);
}
// Removes the operand slots listed in `ascending_indices` by compacting
// `operands_` in place and shrinking it.
//
// `ascending_indices` is expected to be sorted in ascending order; each index
// must be smaller than the current operand count (CHECK-enforced). An empty
// list is a no-op.
//
// Only `operands_` is edited here; this function does not call RemoveUser on
// the operands that are dropped.
void HloInstruction::RemoveOperandsAtAscendingIndices(
    absl::Span<const int> ascending_indices) {
  if (ascending_indices.empty()) {
    return;
  }
  int next_index = 0;     // Next source slot to examine.
  int removed_count = 0;  // Slots dropped so far == distance to shift left.
  for (int to_remove : ascending_indices) {
    // Validate the index before the copy loop below reads slots up to it.
    CHECK_LT(to_remove, operands_.size());
    // Shift the kept slots that precede `to_remove` into place.
    while (next_index < to_remove) {
      operands_[next_index - removed_count] = operands_[next_index];
      ++next_index;
    }
    // Skip over the removed slot.
    ++removed_count;
    ++next_index;
  }
  // Shift everything after the last removed slot.
  while (next_index < operands_.size()) {
    operands_[next_index - removed_count] = operands_[next_index];
    ++next_index;
  }
  CHECK_EQ(removed_count, ascending_indices.size());
  operands_.resize(operands_.size() - removed_count);
}
// Registers `user` as a user of this instruction unless it is already
// registered. `user_map_` records the position of each user within `users_`.
void HloInstruction::AddUser(HloInstruction* user) {
  // emplace only inserts when the key is absent; the mapped value is the
  // index the user is about to occupy in `users_`.
  const bool newly_added = user_map_.emplace(user, users_.size()).second;
  if (newly_added) {
    users_.push_back(user);
  }
}
// Returns the value stored for `user` in `user_map_` (its position in
// `users_`). CHECK-fails if `user` is not a registered user.
int64_t HloInstruction::UserId(HloInstruction* user) {
  const auto result = user_map_.find(user);
  CHECK(result != user_map_.end());
  const int64_t id = result->second;
  return id;
}
bool HloInstruction::HasConstantOperand() const {
for (const HloInstruction* operand : operands_) {
if (operand->IsConstant()) {
return true;
}
}
return false;
}
// Opcode-specific part of the identity comparison. IdenticalInternal calls
// this only after the opcode, shape, operands and backend config of `this`
// and `other` have already been found to match.
//
// Behaviour by opcode group:
//  * opcodes fully determined by opcode + operands: always identical;
//  * kAfterAll / kAddDependency: never identical;
//  * kCall / kConditional / kWhile: identical iff the called computations
//    compare equal under `eq_computations`;
//  * opcodes implemented by subclasses: LOG(FATAL), since those subclasses
//    are expected to supply their own comparison.
bool HloInstruction::IdenticalSlowPath(
const HloInstruction& other,
const std::function<bool(const HloComputation*, const HloComputation*)>&
eq_computations) const {
// Perform opcode specific checks.
switch (opcode()) {
// The result of these instructions only depend upon their opcode and
// operands.
case HloOpcode::kAbs:
case HloOpcode::kAllGatherDone:
case HloOpcode::kAllReduceDone:
case HloOpcode::kAtan2:
case HloOpcode::kAdd:
case HloOpcode::kBitcast:
case HloOpcode::kBitcastConvert:
case HloOpcode::kCeil:
case HloOpcode::kClamp:
case HloOpcode::kClz:
case HloOpcode::kCollectivePermuteDone:
case HloOpcode::kComplex:
case HloOpcode::kConvert:
case HloOpcode::kCopy:
case HloOpcode::kCopyStart:
case HloOpcode::kCopyDone:
case HloOpcode::kCos:
case HloOpcode::kDivide:
case HloOpcode::kDynamicUpdateSlice:
case HloOpcode::kExp:
case HloOpcode::kExpm1:
case HloOpcode::kFloor:
case HloOpcode::kImag:
case HloOpcode::kIsFinite:
case HloOpcode::kLog:
case HloOpcode::kLog1p:
case HloOpcode::kAnd:
case HloOpcode::kNot:
case HloOpcode::kOr:
case HloOpcode::kXor:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kMultiply:
case HloOpcode::kNegate:
case HloOpcode::kOptimizationBarrier:
case HloOpcode::kPartitionId:
case HloOpcode::kPopulationCount:
case HloOpcode::kPower:
case HloOpcode::kReal:
case HloOpcode::kRemainder:
case HloOpcode::kReshape:
case HloOpcode::kDynamicReshape:
case HloOpcode::kReplicaId:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kRoundNearestEven:
case HloOpcode::kRsqrt:
case HloOpcode::kSelect:
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical:
case HloOpcode::kLogistic:
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kSqrt:
case HloOpcode::kCbrt:
case HloOpcode::kSubtract:
case HloOpcode::kTanh:
case HloOpcode::kTuple:
return true;
// This opcode has complex or special behavior so just return false.
case HloOpcode::kAfterAll:
case HloOpcode::kAddDependency:
return false;
// Remaining instructions with special values.
case HloOpcode::kCall:
return eq_computations(to_apply(), other.to_apply());
case HloOpcode::kConditional:
// Branches are compared pairwise, using this instruction's branch count.
for (int j = 0; j < branch_count(); ++j) {
if (!eq_computations(branch_computation(j),
other.branch_computation(j))) {
return false;
}
}
return true;
case HloOpcode::kWhile:
return (eq_computations(while_body(), other.while_body()) &&
eq_computations(while_condition(), other.while_condition()));
// Ops migrated to subclasses should never come to this line.
// TODO(b/80131774): Remove this switch when migration is complete.
case HloOpcode::kAsyncStart:
case HloOpcode::kAsyncUpdate:
case HloOpcode::kAsyncDone:
case HloOpcode::kBatchNormTraining:
case HloOpcode::kBatchNormInference:
case HloOpcode::kBatchNormGrad:
case HloOpcode::kFft:
case HloOpcode::kCompare:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
case HloOpcode::kReverse:
case HloOpcode::kConcatenate:
case HloOpcode::kReduce:
case HloOpcode::kSort:
case HloOpcode::kTranspose:
case HloOpcode::kBroadcast:
case HloOpcode::kMap:
case HloOpcode::kSlice:
case HloOpcode::kConstant:
case HloOpcode::kIota:
case HloOpcode::kFusion:
case HloOpcode::kRng:
case HloOpcode::kRngBitGenerator:
case HloOpcode::kRngGetAndUpdateState:
case HloOpcode::kParameter:
case HloOpcode::kGetTupleElement:
case HloOpcode::kReducePrecision:
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kAllGather:
case HloOpcode::kAllGatherStart:
case HloOpcode::kAllReduce:
case HloOpcode::kReduceScatter:
case HloOpcode::kAllReduceStart:
case HloOpcode::kAllToAll:
case HloOpcode::kCollectivePermute:
case HloOpcode::kCollectivePermuteStart:
case HloOpcode::kConvolution:
case HloOpcode::kCustomCall:
case HloOpcode::kReduceWindow:
case HloOpcode::kSelectAndScatter:
case HloOpcode::kPad:
case HloOpcode::kDynamicSlice:
case HloOpcode::kGather:
case HloOpcode::kScatter:
case HloOpcode::kDot:
case HloOpcode::kDomain:
case HloOpcode::kGetDimensionSize:
case HloOpcode::kSetDimensionSize:
case HloOpcode::kTriangularSolve:
case HloOpcode::kCholesky:
LOG(FATAL) << "Base class impl called for opcode with subclass: "
<< opcode();
}
// Reached only if opcode() matches none of the cases above.
return false;
}
// Unregisters `user` from this instruction's user list. CHECK-fails if `user`
// is not registered or if `user_map_` and `users_` disagree about its slot.
//
// Removal swaps the last entry of `users_` into the vacated slot and pops the
// back, so the relative order of the remaining users may change.
void HloInstruction::RemoveUser(HloInstruction* user) {
  auto map_it = user_map_.find(user);
  CHECK(map_it != user_map_.end());

  const int64_t index = map_it->second;
  CHECK_EQ(users_[index], user);

  // Relocate the last user into the freed slot and record its new position.
  HloInstruction* const last_user = users_.back();
  users_[index] = last_user;
  user_map_[last_user] = index;

  // Drop the removed user's map entry and the now-duplicated last slot.
  user_map_.erase(map_it);
  users_.pop_back();
}
// Replaces every use of this instruction in `user` with `new_producer`.
// Returns an error unless the two shapes are compatible when floating-point
// precision is ignored; the replacement itself is done by
// ReplaceUseWithDifferentShape.
Status HloInstruction::ReplaceUseWith(HloInstruction* user,
                                      HloInstruction* new_producer) {
  TF_RET_CHECK(
      ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
      << "this shape: " << ShapeUtil::HumanString(shape())
      << ", replacement shape: "
      << ShapeUtil::HumanString(new_producer->shape());
  Status status = ReplaceUseWithDifferentShape(user, new_producer);
  return status;
}
// Replaces every operand slot of `user` that refers to this instruction with
// `new_producer`, without any shape compatibility check. `user` is removed
// from this instruction's users and added to `new_producer`'s users. If
// `user` is a fusion instruction, its operands are deduplicated afterwards.
Status HloInstruction::ReplaceUseWithDifferentShape(
    HloInstruction* user, HloInstruction* new_producer) {
  VLOG(3) << "Replacing uses of " << name() << " in " << user->name()
          << " with " << new_producer->name();

  RemoveUser(user);

  // NOTE(review): a count can never be negative, so this check cannot fail;
  // confirm whether `> 0` was intended.
  TF_RET_CHECK(absl::c_count(user->operands_, this) >= 0);
  absl::c_replace(user->operands_, this, new_producer);
  new_producer->AddUser(user);

  if (user->opcode() != HloOpcode::kFusion) {
    return OkStatus();
  }
  // The replacement may have introduced repeated operands on a fusion; let
  // the fusion instruction deduplicate them.
  return Cast<HloFusionInstruction>(user)->DeduplicateFusionOperands();
}
// Replaces a single use of this instruction — operand slot `operand_number`
// of `user` — with `new_producer`. Returns an error unless the two shapes are
// compatible when floating-point precision is ignored; the replacement itself
// is done by the matching ReplaceUseWithDifferentShape overload.
Status HloInstruction::ReplaceUseWith(HloInstruction* user, int operand_number,
                                      HloInstruction* new_producer) {
  TF_RET_CHECK(
      ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
      << "this shape: " << ShapeUtil::HumanString(shape())
      << ", replacement shape: "
      << ShapeUtil::HumanString(new_producer->shape());
  Status status =
      ReplaceUseWithDifferentShape(user, operand_number, new_producer);
  return status;
}
// Replaces operand slot `operand_number` of `user` — which must currently
// hold this instruction — with `new_producer`, without any shape
// compatibility check.
//
// Returns an error, leaving both instructions untouched, if that slot does
// not hold this instruction. The slot is validated before any bookkeeping is
// changed so that a failed check cannot drop `user` from this instruction's
// user list while the operand still points here.
//
// `user` is removed from this instruction's users only when the replaced slot
// is its sole reference to this instruction; it is always added to
// `new_producer`'s users.
Status HloInstruction::ReplaceUseWithDifferentShape(
    HloInstruction* user, int operand_number, HloInstruction* new_producer) {
  VLOG(3) << "Replacing operand " << operand_number << " of " << name()
          << " in " << user->name() << " with " << new_producer->name();

  TF_RET_CHECK(user->operand(operand_number) == this)
      << "Expected operand " << operand_number << " of " << user->ToString()
      << " to be equal to " << ToString();

  // If other slots of `user` still refer to this instruction, it remains a
  // user after the replacement.
  if (absl::c_count(user->operands_, this) == 1) {
    RemoveUser(user);
  }

  user->operands_[operand_number] = new_producer;
  new_producer->AddUser(user);
  return OkStatus();
}
// Replaces operand `operand_num` of this instruction with `new_operand`.
// Returns an error unless the old and new operand shapes are compatible when
// floating-point precision is ignored; the replacement itself is done by
// ReplaceOperandWithDifferentShape.
Status HloInstruction::ReplaceOperandWith(int64_t operand_num,
                                          HloInstruction* new_operand) {
  const HloInstruction* old_operand = operand(operand_num);
  TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(old_operand->shape(),
                                                        new_operand->shape()))
      << old_operand->shape() << " is not compatible with "
      << new_operand->shape();
  Status status = ReplaceOperandWithDifferentShape(operand_num, new_operand);
  return status;
}
// Replaces operand `operand_num` of this instruction with `new_operand`
// without checking shape compatibility. Replacing an operand with itself is a
// no-op. Returns an error status if `operand_num` is out of range.
Status HloInstruction::ReplaceOperandWithDifferentShape(
    int64_t operand_num, HloInstruction* new_operand) {
  TF_RET_CHECK(operand_num >= 0);
  TF_RET_CHECK(operand_num < operand_count());

  HloInstruction* const previous = mutable_operand(operand_num);
  if (previous != new_operand) {
    operands_[operand_num] = new_operand;

    VLOG(3) << "Replacing operand " << operand_num << " of " << name()
            << " with " << new_operand->name() << ", was "
            << previous->name();

    // The old operand keeps this instruction as a user while any other
    // operand slot still refers to it.
    const bool still_referenced = absl::c_linear_search(operands_, previous);
    if (!still_referenced) {
      previous->RemoveUser(this);
    }
    new_operand->AddUser(this);
  }
  return OkStatus();
}
// Shape-checked replacement for a chosen set of users: verifies that this
// instruction's shape is compatible with `new_producer`'s (ignoring
// floating-point precision), then forwards to the span overload of
// ReplaceAllUsesWithDifferentShape().
Status HloInstruction::ReplaceUsesWith(absl::Span<HloInstruction* const> users,
                                       HloInstruction* new_producer) {
  TF_RET_CHECK(
      ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
      << shape() << " is not compatible with " << new_producer->shape();

  return ReplaceAllUsesWithDifferentShape(users, new_producer);
}
// Replaces this instruction with `new_producer` in each of the given `users`,
// without checking shape compatibility. If this instruction is the root of
// its parent computation, `new_producer` becomes the new root. Stops at, and
// returns, the first error from a per-user replacement.
Status HloInstruction::ReplaceAllUsesWithDifferentShape(
    absl::Span<HloInstruction* const> users, HloInstruction* new_producer) {
  for (HloInstruction* user : users) {
    TF_RETURN_IF_ERROR(ReplaceUseWithDifferentShape(user, new_producer));
  }

  const bool is_root = parent_ && parent_->root_instruction() == this;
  if (is_root) {
    parent_->set_root_instruction(new_producer,
                                  /*accept_different_shape=*/true);
  }
  return OkStatus();
}
// Shape-checked replacement of all uses: verifies that this instruction's
// shape is compatible with `new_producer`'s (ignoring floating-point
// precision), then forwards to ReplaceAllUsesWithDifferentShape().
Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) {
  TF_RET_CHECK(
      ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
      << shape() << " is not compatible with " << new_producer->shape();

  return ReplaceAllUsesWithDifferentShape(new_producer);
}
// Replaces this instruction with `new_producer` in every user, without
// checking shape compatibility. A user that is `new_producer` itself is left
// alone. If this instruction is the root of its parent computation,
// `new_producer` becomes the new root.
Status HloInstruction::ReplaceAllUsesWithDifferentShape(
    HloInstruction* new_producer) {
  bool new_producer_is_user = false;
  for (HloInstruction* user : users()) {
    if (user == new_producer) {
      // It's possible that new_producer is a user of this instruction as might
      // be the case when replacing an instruction with a kCopy of itself. In
      // this case, don't do the replacement to avoid creating a cycle in the
      // graph. new_producer remains the only user of this instruction.
      new_producer_is_user = true;
      continue;
    }

    absl::c_replace(user->operands_, this, new_producer);
    new_producer->AddUser(user);
    if (user->opcode() == HloOpcode::kFusion) {
      TF_RETURN_IF_ERROR(
          Cast<HloFusionInstruction>(user)->DeduplicateFusionOperands());
    }
  }

  // All user edges have been moved; reset the bookkeeping and re-add
  // new_producer if it was skipped above.
  users_.clear();
  user_map_.clear();
  if (new_producer_is_user) {
    AddUser(new_producer);
  }

  const bool is_root = parent_ && parent_->root_instruction() == this;
  if (is_root) {
    parent_->set_root_instruction(new_producer,
                                  /*accept_different_shape=*/true);
  }
  return OkStatus();
}
// Returns true for a kBitcast, or for a kTranspose that
// ShapeUtil::TransposeIsBitcast() reports as a bitcast for its operand shape,
// result shape and dimensions.
bool HloInstruction::IsEffectiveBitcast() const {
  if (opcode_ == HloOpcode::kBitcast) {
    return true;
  }
  if (opcode_ != HloOpcode::kTranspose) {
    return false;
  }
  return ShapeUtil::TransposeIsBitcast(operand(0)->shape(), shape(),
                                       dimensions());
}
// Returns the single called computation of an instruction whose opcode is in
// the list below. Fatally logs for any other opcode, and CHECK-fails if the
// instruction does not have exactly one called computation.
HloComputation* HloInstruction::to_apply() const {
  switch (opcode_) {
    case HloOpcode::kCall:
    case HloOpcode::kMap:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kReduce:
    case HloOpcode::kAllReduce:
    case HloOpcode::kReduceScatter:
    case HloOpcode::kAllReduceStart:
    case HloOpcode::kScatter:
    case HloOpcode::kSort:
    case HloOpcode::kCustomCall:
      break;
    default:
      LOG(FATAL) << "Invalid opcode for to_apply(): "
                 << HloOpcodeString(opcode());
  }
  CHECK_EQ(called_computations_.size(), 1);
  return called_computations_[0];
}
// Replaces the single called computation of an instruction whose opcode is in
// the list below. The accepted opcodes mirror those of to_apply(), including
// kReduceScatter, which the getter accepted but this setter used to reject.
//
// CHECK-fails if the instruction is fused or does not have exactly one called
// computation; fatally logs for any other opcode.
void HloInstruction::set_to_apply(HloComputation* computation) {
  // Don't allow changing the computation for fused instructions so we don't
  // have to recompute called_instructions for the entire fusion instruction.
  CHECK(!IsFused());
  switch (opcode_) {
    case HloOpcode::kCall:
    case HloOpcode::kMap:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kReduce:
    case HloOpcode::kAllReduce:
    case HloOpcode::kReduceScatter:
    case HloOpcode::kAllReduceStart:
    case HloOpcode::kScatter:
    case HloOpcode::kSort:
    case HloOpcode::kCustomCall:
      CHECK_EQ(called_computations_.size(), 1);
      called_computations_[0] = computation;
      break;
    default:
      LOG(FATAL) << "Invalid opcode for to_apply(): "
                 << HloOpcodeString(opcode());
  }
}
// Returns the called computation stored at kConditionComputationIndex.
// CHECK-fails unless this is a kWhile instruction.
HloComputation* HloInstruction::while_condition() const {
  CHECK_EQ(HloOpcode::kWhile, opcode_);
  HloComputation* const condition =
      called_computations_[kConditionComputationIndex];
  return condition;
}
// Returns the called computation stored at kBodyComputationIndex.
// CHECK-fails unless this is a kWhile instruction.
HloComputation* HloInstruction::while_body() const {
  CHECK_EQ(HloOpcode::kWhile, opcode_);
  HloComputation* const body = called_computations_[kBodyComputationIndex];
  return body;
}
// Stores `computation` at kConditionComputationIndex of the called
// computations. CHECK-fails if this instruction is fused or is not a kWhile.
void HloInstruction::set_while_condition(HloComputation* computation) {
  // Fused instructions may not have their computation swapped: doing so would
  // require recomputing called_instructions for the whole fusion instruction.
  CHECK(!IsFused());
  CHECK_EQ(HloOpcode::kWhile, opcode_);

  HloComputation*& condition_slot =
      called_computations_[kConditionComputationIndex];
  condition_slot = computation;
}
// Stores `computation` at kBodyComputationIndex of the called computations.
// CHECK-fails if this instruction is fused or is not a kWhile.
void HloInstruction::set_while_body(HloComputation* computation) {
  // Fused instructions may not have their computation swapped: doing so would
  // require recomputing called_instructions for the whole fusion instruction.
  CHECK(!IsFused());
  CHECK_EQ(HloOpcode::kWhile, opcode_);

  HloComputation*& body_slot = called_computations_[kBodyComputationIndex];
  body_slot = computation;
}
// Returns operand 0 of this instruction. CHECK-fails unless this is a kWhile
// instruction.
HloInstruction* HloInstruction::while_init() const {
  CHECK_EQ(HloOpcode::kWhile, opcode_);
  return operands_.front();
}
// Returns the called computation stored at kTrueComputationIndex.
// CHECK-fails unless this is a kConditional whose operand 0 has element type
// PRED.
HloComputation* HloInstruction::true_computation() const {
  CHECK_EQ(HloOpcode::kConditional, opcode_);
  CHECK_EQ(PRED, operand(0)->shape().element_type());
  HloComputation* const on_true = called_computations_[kTrueComputationIndex];
  return on_true;
}
// Returns the called computation stored at kFalseComputationIndex.
// CHECK-fails unless this is a kConditional whose operand 0 has element type
// PRED.
HloComputation* HloInstruction::false_computation() const {
  CHECK_EQ(HloOpcode::kConditional, opcode_);
  CHECK_EQ(PRED, operand(0)->shape().element_type());
  HloComputation* const on_false =
      called_computations_[kFalseComputationIndex];
  return on_false;
}
// Returns a reference to the full list of called computations.
// CHECK-fails unless this is a kConditional instruction.
const std::vector<HloComputation*>& HloInstruction::branch_computations()
    const {
  CHECK(HloOpcode::kConditional == opcode_);
  const std::vector<HloComputation*>& branches = called_computations_;
  return branches;
}
// Returns the number of called computations, converted to int.
// CHECK-fails unless this is a kConditional instruction.
int HloInstruction::branch_count() const {
  CHECK(HloOpcode::kConditional == opcode_);
  return static_cast<int>(called_computations_.size());
}
// Returns the called computation at index `b`. CHECK-fails unless this is a
// kConditional instruction and `b` is within [0, number of called
// computations).
HloComputation* HloInstruction::branch_computation(int b) const {
  CHECK(HloOpcode::kConditional == opcode_);
  CHECK_GE(b, 0);
  CHECK_LT(b, called_computations_.size());
  HloComputation* const branch = called_computations_[b];
  return branch;
}
void HloInstruction::set_branch_computation(int b,
HloComputation* computation) {
// Don't allow changing the computation for fused instructions so we don't
// have to recompute called_instructions for the entire fusion instruction.
CHECK(!IsFused());
CHECK_EQ(HloOpcode::kConditional, opcode_);
called_computations_[b] = computation;
}
std::string HloInstruction::SignatureString() const {
std::string operands =
StrJoin(operands_, ", ", [](std::string* out, HloInstruction* operand) {
StrAppend(out, ShapeUtil::HumanString(operand->shape()));
});
return StrCat("(", operands, ") -> ", ShapeUtil::HumanString(shape()));
}
// Returns `name` unchanged when `print_ids` is true. Otherwise returns the
// part of `name` before its first '.', or all of `name` if it has no '.'.
std::string PrintName(const std::string& name, bool print_ids) {
  if (print_ids) {
    return name;
  }
  // find() yields npos when there is no dot, which makes substr() return the
  // whole string.
  const std::string::size_type first_dot = name.find('.');
  return name.substr(0, first_dot);
}
namespace {
// Stack type used by the DFS helpers in this file. Each entry pairs an int
// with an instruction pointer; PushDFSChild() stores the instruction's
// unique_id() as the int.
using DFSStack = absl::InlinedVector<std::pair<int, HloInstruction*>, 16>;
// Formats `name` via PrintName() according to options.print_ids(), prefixed
// with "%" when options.print_percent() is set.
std::string PrintNameInternal(const std::string& name,
                              const HloPrintOptions& options) {
  const char* const prefix = options.print_percent() ? "%" : "";
  return StrCat(prefix, PrintName(name, options.print_ids()));
}
// Logs (at INFO) one directed cycle through `child`, restricted to the
// instructions that sit above `child` on `dfs_stack`.
//
// `dfs_stack` is consumed: entries are popped until `child` is on top or the
// stack is empty. Nothing is logged if no cycle back to `child` is found.
void PrintCycle(const HloInstruction* child, DFSStack* dfs_stack) {
  // This set contains HloInstructions from the top of `DFSStack` that might
  // belong to the cycle, i.e. if DFSStack :=[back,...,child,...,top], then
  // `subgraph` := {child,...,top}.
  absl::flat_hash_set<const HloInstruction*> subgraph;
  while (!dfs_stack->empty() && dfs_stack->back().second != child) {
    subgraph.insert(dfs_stack->back().second);
    dfs_stack->pop_back();
  }
  // Start dfs at `child` and find a cycle with all nodes in `subgraph`.
  // `dfs` always holds the current path from `child`, so when the walk gets
  // back to `child` the vector is exactly the cycle to print.
  absl::flat_hash_set<const HloInstruction*> visited;
  absl::InlinedVector<const HloInstruction*, 16> dfs;
  dfs.push_back(child);
  while (!dfs.empty()) {
    bool found_next_instr = false;
    for (const auto& user : dfs.back()->users()) {
      if (user == child) {
        dfs.push_back(child);
        LOG(INFO) << "\n\nDirected cycle:\n "
                  << absl::StrJoin(
                         dfs, "\n ",
                         [](std::string* out, const HloInstruction* instr) {
                           out->append(instr->name());
                         });
        return;
      }
      if (!subgraph.contains(user) || visited.contains(user)) {
        continue;
      }
      visited.insert(user);
      dfs.push_back(user);
      found_next_instr = true;
      // Descend into this user before considering its siblings. Pushing all
      // siblings at once would leave instructions on `dfs` that are not on
      // the path, and they would show up in the printed cycle.
      break;
    }
    if (!found_next_instr) {
      // Dead end: every user of the top node is outside the subgraph or has
      // already been explored, so backtrack.
      dfs.pop_back();
    }
  }
}
} // namespace
std::string HloInstruction::ToString(const HloPrintOptions& options) const {
CanonicalNameMap new_map;
return ToStringWithCanonicalNameMap(options, &new_map);
}
// Returns true when `opcode` is one of the unary, binary or ternary
// elementwise opcodes listed below, and false for every other opcode.
bool HloInstruction::IsOpElementwise(HloOpcode opcode) {
  switch (opcode) {
    // Unary elementwise operations.
    case HloOpcode::kAbs:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kRoundNearestEven:
    case HloOpcode::kCeil:
    case HloOpcode::kClz:
    case HloOpcode::kConvert:
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kCopy:
    case HloOpcode::kCos:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kFloor:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kNot:
    case HloOpcode::kNegate:
    case HloOpcode::kPopulationCount:
    case HloOpcode::kReal:
    case HloOpcode::kReducePrecision:
    case HloOpcode::kRsqrt:
    case HloOpcode::kLogistic:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSqrt:
    case HloOpcode::kCbrt:
    case HloOpcode::kTanh:
    // Binary elementwise operations, the same as in IsElementwiseBinary().
    case HloOpcode::kAdd:
    case HloOpcode::kAtan2:
    case HloOpcode::kCompare:
    case HloOpcode::kComplex:
    case HloOpcode::kDivide:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kPower:
    case HloOpcode::kRemainder:
    case HloOpcode::kSubtract:
    case HloOpcode::kAnd:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
    // Ternary elementwise operations.
    case HloOpcode::kSelect:
    case HloOpcode::kClamp:
      return true;
    default:
      break;
  }
  return false;
}
// Decides whether this instruction counts as elementwise, optionally with
// respect to a specific operand index:
//  - kDynamicUpdateSlice: only when `operand_idx` is set and equals 0.
//  - kBitcastConvert: not when the result and operand 0 element types have
//    different bit widths.
//  - anything else (and same-width kBitcastConvert): IsOpElementwise().
bool HloInstruction::IsElementwiseImpl(
    const std::optional<int64_t>& operand_idx) const {
  switch (opcode_) {
    case HloOpcode::kDynamicUpdateSlice:
      return operand_idx.has_value() && operand_idx.value() == 0;
    case HloOpcode::kBitcastConvert:
      if (primitive_util::BitWidth(shape_.element_type()) !=
          primitive_util::BitWidth(operands_[0]->shape().element_type())) {
        return false;
      }
      break;
    default:
      break;
  }
  return IsOpElementwise(opcode_);
}
// True for a kAllReduce instruction whose channel_id() holds a value.
// channel_id() is only consulted when the opcode is kAllReduce.
bool HloInstruction::IsCrossModuleAllReduce() const {
  if (opcode() != HloOpcode::kAllReduce) {
    return false;
  }
  return channel_id().has_value();
}
// True for a kAllReduce instruction whose channel_id() holds no value.
// channel_id() is only consulted when the opcode is kAllReduce.
bool HloInstruction::IsCrossReplicaAllReduce() const {
  if (opcode() != HloOpcode::kAllReduce) {
    return false;
  }
  return !channel_id().has_value();
}
std::string HloInstruction::ToStringWithCanonicalNameMap(
const HloPrintOptions& options,
CanonicalNameMap* canonical_name_map) const {
std::string result = "";
// Logic to print the instruction name (e.g. "%foo = ").
if (options.canonicalize_instruction_names()) {
if (options.is_in_nested_computation()) {
// If we are canonicalizing instruction names and this is a top-level
// HloInstruction::ToString() call, don't print an instruction name.
DCHECK(!options.print_percent()); // no need to call PrintNameInternal
StrAppend(&result, canonical_name_map->LookupOrInsert(name()), " = ");
}
} else {
StrAppend(&result, PrintNameInternal(name(), options), " = ");
}
if (options.print_result_shape()) {
// Print shape.
if (options.include_layout_in_shapes()) {
StrAppend(&result, ShapeUtil::HumanStringWithLayout(shape()), " ");
} else {
StrAppend(&result, ShapeUtil::HumanString(shape()), " ");
}
}
// Print opcode, operand(s).
if (options.syntax_sugar_async_ops() && HloOpcodeIsAsync(opcode())) {
std::string suffix = [&]() {
switch (opcode()) {
case HloOpcode::kAsyncStart:
return "-start";
case HloOpcode::kAsyncUpdate:
return "-update";
default:
CHECK(opcode() == HloOpcode::kAsyncDone)
<< "Unexpected async opcode: " << HloOpcodeString(opcode());
return "-done";
}
}();
StrAppend(&result, HloOpcodeString(async_wrapped_opcode()), suffix);
} else {
StrAppend(&result, HloOpcodeString(opcode()));
}
StrAppend(&result, "(",
OperandsToStringWithCanonicalNameMap(options, canonical_name_map),
")");
// Print additional attributes. If an instruction contains a subcomputation,
// the subcomputation is also printed here.
for (const std::string& extra : ExtraAttributesToString(options)) {
StrAppend(&result, ", ", extra);
}
if (options.print_metadata() &&
(!metadata_.op_type().empty() || !metadata_.op_name().empty() ||
!metadata_.source_file().empty())) {
StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}");
}
if (options.print_backend_config() && !backend_config_.empty()) {
StrAppend(&result, ", backend_config=\"",
CEscape(backend_config_.GetRawString()), "\"");
}
return result;
}
std::string HloInstruction::OperandsToString(
const HloPrintOptions& options) const {
CanonicalNameMap new_map;
return OperandsToStringWithCanonicalNameMap(options, &new_map);
}
std::string HloInstruction::OperandsToStringWithCanonicalNameMap(
const HloPrintOptions& options,
CanonicalNameMap* canonical_name_map) const {
std::string operands;
absl::Span<HloInstruction* const> slice(operands_);
const int64_t kMaxOperandsToShowIfCompact = 4;
if (options.compact_operands() &&
slice.size() > kMaxOperandsToShowIfCompact) {
slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact);
}
for (int64_t i = 0; i < slice.size(); ++i) {
HloInstruction* operand = slice[i];
if (i != 0) {
StrAppend(&operands, ", ");
if (options.print_operand_index_annotation_interval() != 0 &&
i % options.print_operand_index_annotation_interval() == 0) {
StrAppend(&operands, absl::StrFormat("/*index=%lld*/", i));
}
}
// If operand is already been deleted, put `null` to the string output.
if (operand == nullptr) {
StrAppend(&operands, "null ");
continue;
}
std::vector<std::string> str;
if (options.print_operand_shape()) {
if (options.include_layout_in_shapes()) {
str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape()));
} else {
str.push_back(ShapeUtil::HumanString(operand->shape()));
}
}
if (options.canonicalize_instruction_names()) {
if (options.is_in_nested_computation()) {
// In a top-level HloInstruction::ToString() call, the operand name is
// not part of the canonical string.
DCHECK(!options.print_percent()); // no need to call PrintNameInternal
str.push_back(canonical_name_map->LookupOrInsert(operand->name()));
}
} else if (options.print_operand_names()) {
str.push_back(PrintNameInternal(operand->name(), options));
}
StrAppend(&operands, StrJoin(str, " "));
}
const int64_t remaining = operands_.size() - slice.size();
if (slice.size() != operands_.size()) {
StrAppend(&operands, ", ...(+", remaining, ")");
}
return operands;
}
namespace {
// Returns true for kCall, kConditional and kWhile, and false for every other
// opcode.
bool IsSequentialCall(HloOpcode opcode) {
  return opcode == HloOpcode::kCall || opcode == HloOpcode::kConditional ||
         opcode == HloOpcode::kWhile;
}
}  // namespace
// Collects the "key=value" attribute strings printed after the operand list.
//
// The result starts with ExtraAttributesToStringImpl() (when
// options.print_extra_attributes() is set), followed by the called
// computations -- by name or as full bodies depending on
// options.print_subcomputation_mode() -- and then sharding, frontend
// attributes, outer dimension partitions and control predecessors.
//
// kReduceScatter is printed with the "to_apply=" key in both the name-only
// and the full-body modes; previously the full-body switch lacked that case
// and fell through to "calls=".
std::vector<std::string> HloInstruction::ExtraAttributesToString(
    const HloPrintOptions& options) const {
  std::vector<std::string> extra = options.print_extra_attributes()
                                       ? ExtraAttributesToStringImpl(options)
                                       : std::vector<std::string>();
  const auto subcomputation_mode = options.print_subcomputation_mode();
  if (subcomputation_mode ==
      HloPrintOptions::PrintSubcomputationMode::kNameOnly) {
    // Name-only mode: refer to each called computation by its printed name.
    if (opcode() == HloOpcode::kWhile) {
      extra.push_back(StrCat(
          "condition=", PrintNameInternal(while_condition()->name(), options)));
      extra.push_back(
          StrCat("body=", PrintNameInternal(while_body()->name(), options)));
    } else if (opcode() == HloOpcode::kSelectAndScatter) {
      extra.push_back(
          StrCat("select=", PrintNameInternal(select()->name(), options)));
      extra.push_back(
          StrCat("scatter=", PrintNameInternal(scatter()->name(), options)));
    } else if (opcode() == HloOpcode::kConditional) {
      // A PRED-typed operand 0 selects the true/false form; any other element
      // type uses the branch list form.
      if (operand(0)->shape().element_type() == PRED) {
        extra.push_back(
            StrCat("true_computation=",
                   PrintNameInternal(true_computation()->name(), options)));
        extra.push_back(
            StrCat("false_computation=",
                   PrintNameInternal(false_computation()->name(), options)));
      } else {
        extra.push_back(StrCat(
            "branch_computations={",
            StrJoin(branch_computations(), ", ",
                    [&](std::string* out, const HloComputation* computation) {
                      StrAppend(
                          out, PrintNameInternal(computation->name(), options));
                    }),
            "}"));
      }
    } else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap ||
               opcode() == HloOpcode::kReduceWindow ||
               opcode() == HloOpcode::kReduce ||
               opcode() == HloOpcode::kAllReduce ||
               opcode() == HloOpcode::kReduceScatter ||
               opcode() == HloOpcode::kAllReduceStart ||
               opcode() == HloOpcode::kScatter ||
               opcode() == HloOpcode::kSort) {
      extra.push_back(
          StrCat("to_apply=", PrintNameInternal(to_apply()->name(), options)));
    } else if (opcode() == HloOpcode::kCustomCall) {
      if (!called_computations().empty()) {
        extra.push_back(StrCat(
            "called_computations={",
            StrJoin(called_computations(), ", ",
                    [&](std::string* out, const HloComputation* computation) {
                      StrAppend(
                          out, PrintNameInternal(computation->name(), options));
                    }),
            "}"));
      }
    } else if (HloOpcodeIsAsync(opcode())) {
      // With async syntax sugar the wrapped computation is not printed here.
      if (!options.syntax_sugar_async_ops()) {
        extra.push_back(StrCat(
            "calls=",
            PrintNameInternal(async_wrapped_computation()->name(), options)));
      }
    } else if (!called_computations().empty()) {
      extra.push_back(StrCat(
          "calls=",
          StrJoin(called_computations(), ", ",
                  [&](std::string* out, const HloComputation* computation) {
                    StrAppend(out,
                              PrintNameInternal(computation->name(), options));
                  })));
    }
  } else if ((subcomputation_mode ==
              HloPrintOptions::PrintSubcomputationMode::kFullBodies) ||
             (subcomputation_mode == HloPrintOptions::PrintSubcomputationMode::
                                         kNonSequentialBodies &&
              !IsSequentialCall(opcode()))) {
    // Full-body mode: print each called computation in place, marked as
    // nested.
    HloPrintOptions new_options = options;
    new_options.set_is_in_nested_computation(true);
    switch (opcode()) {
      case HloOpcode::kWhile:
        extra.push_back(
            StrCat("condition=\n", while_condition()->ToString(new_options)));
        extra.push_back(StrCat("body=\n", while_body()->ToString(new_options)));
        break;
      case HloOpcode::kSelectAndScatter:
        extra.push_back(StrCat("select=\n", select()->ToString(new_options)));
        extra.push_back(StrCat("scatter=\n", scatter()->ToString(new_options)));
        break;
      case HloOpcode::kConditional:
        if (operand(0)->shape().element_type() == PRED) {
          extra.push_back(StrCat("true_computation=\n",
                                 true_computation()->ToString(new_options)));
          extra.push_back(StrCat("false_computation=\n",
                                 false_computation()->ToString(new_options)));
        } else {
          extra.push_back(StrCat(
              "branch_computations={\n",
              StrJoin(branch_computations(), ",\n",
                      [&](std::string* out, const HloComputation* computation) {
                        StrAppend(out, computation->ToString(new_options));
                      }),
              "\n}"));
        }
        break;
      case HloOpcode::kCall:
      case HloOpcode::kMap:
      case HloOpcode::kReduceWindow:
      case HloOpcode::kReduce:
      case HloOpcode::kAllReduce:
      case HloOpcode::kReduceScatter:
      case HloOpcode::kAllReduceStart:
      case HloOpcode::kScatter:
      case HloOpcode::kSort:
        extra.push_back(
            StrCat("to_apply=\n", to_apply()->ToString(new_options)));
        break;
      default:
        if (!called_computations().empty()) {
          extra.push_back(StrCat(
              "calls=\n",
              StrJoin(called_computations(), ", ",
                      [&](std::string* out, const HloComputation* computation) {
                        StrAppend(out, computation->ToString(new_options));
                      })));
        }
        break;
    }
  }
  // Attributes that do not depend on the subcomputation mode.
  if (has_sharding()) {
    extra.push_back(
        StrCat("sharding=", sharding().ToString(options.print_metadata())));
  }
  if (!frontend_attributes_.map().empty()) {
    extra.push_back(StrCat("frontend_attributes=",
                           FrontendAttributesToString(frontend_attributes_)));
  }
  if (!outer_dimension_partitions_.empty()) {
    extra.push_back(absl::StrFormat("outer_dimension_partitions={%s}",
                                    StrJoin(outer_dimension_partitions_, ",")));
  }
  if (options.print_control_dependencies() && !control_predecessors_.empty()) {
    extra.push_back(StrCat("control-predecessors={",
                           StrJoin(control_predecessors_, ", ",
                                   [&](std::string* out, HloInstruction* pre) {
                                     StrAppend(out, PrintNameInternal(
                                                        pre->name(), options));
                                   }),
                           "}"));
  }
  return extra;
}
std::string HloInstruction::ToShortString() const {
return StrCat("%", name(), " = ", HloOpcodeString(opcode()), "(",
StrJoin(operands_, ", ",
[](std::string* out, HloInstruction* operand) {
StrAppend(out, "%", operand->name());
}),
")");
}
// Serialises the fields visible here into an HloInstructionProto: id, name,
// opcode, shape, operand and control-predecessor ids, metadata, backend
// config, called computation ids (skipped for kFusion), sharding, outer
// dimension partitions and frontend attributes.
//
// CHECK-fails if unique_id_ is still -1.
HloInstructionProto HloInstruction::ToProto() const {
  CHECK(unique_id_ != -1)
      << "This instruction does not have a valid id. Please make sure the "
         "instruction is inside a module before dumping it.";

  HloInstructionProto proto;
  proto.set_id(unique_id_);
  proto.set_name(name_);
  proto.set_opcode(HloOpcodeString(opcode_));
  *proto.mutable_shape() = shape_.ToProto();

  for (const HloInstruction* operand : operands_) {
    proto.add_operand_ids(operand->unique_id());
  }
  for (const HloInstruction* predecessor : control_predecessors_) {
    proto.add_control_predecessor_ids(predecessor->unique_id());
  }

  *proto.mutable_metadata() = metadata_;
  proto.set_backend_config(backend_config_.GetRawString());

  // Called computation ids are not written for fusion instructions.
  if (opcode() != HloOpcode::kFusion) {
    for (const HloComputation* callee : called_computations_) {
      proto.add_called_computation_ids(callee->unique_id());
    }
  }

  if (has_sharding()) {
    *proto.mutable_sharding() = sharding().ToProto();
  }

  // The loop body never runs for an empty list, so no separate emptiness
  // check is needed.
  for (const auto& partition : outer_dimension_partitions_) {
    proto.mutable_outer_dimension_partitions()->Add(partition);
  }

  *proto.mutable_frontend_attributes() = frontend_attributes_;
  return proto;
}
// Returns a category label for this instruction: "data formatting" for
// transpose/copy/reshape/dynamic-reshape, "non-fusion elementwise" for other
// elementwise instructions, and the opcode string for everything else.
std::string HloInstruction::ToCategory() const {
  switch (opcode()) {
    case HloOpcode::kTranspose:
    case HloOpcode::kCopy:
    case HloOpcode::kReshape:
    case HloOpcode::kDynamicReshape:
      return "data formatting";
    default:
      break;
  }
  if (IsElementwise()) {
    return "non-fusion elementwise";
  }
  return HloOpcodeString(opcode());
}
// True when this instruction has a parent computation and that computation
// reports itself as a fusion computation.
bool HloInstruction::IsFused() const {
  if (parent_ == nullptr) {
    return false;
  }
  return parent_->IsFusionComputation();
}
// True when this is a kCustomCall whose custom_call_target() equals `target`.
// The target is only read when the opcode matches.
bool HloInstruction::IsCustomCall(absl::string_view target) const {
  if (opcode() != HloOpcode::kCustomCall) {
    return false;
  }
  return custom_call_target() == target;
}
// True when this is a kFusion instruction of kind FusionKind::kInput.
// fusion_kind() is only read when the opcode is kFusion.
bool HloInstruction::IsInputFusion() const {
  if (opcode() != HloOpcode::kFusion) {
    return false;
  }
  return fusion_kind() == FusionKind::kInput;
}
// True when this is a kFusion instruction of kind FusionKind::kLoop.
// fusion_kind() is only read when the opcode is kFusion.
bool HloInstruction::IsLoopFusion() const {
  if (opcode() != HloOpcode::kFusion) {
    return false;
  }
  return fusion_kind() == FusionKind::kLoop;
}
// True when this is a kFusion instruction of kind FusionKind::kOutput.
// fusion_kind() is only read when the opcode is kFusion.
bool HloInstruction::IsOutputFusion() const {
  if (opcode() != HloOpcode::kFusion) {
    return false;
  }
  return fusion_kind() == FusionKind::kOutput;
}
// True when this is a kFusion instruction of kind FusionKind::kCustom.
// fusion_kind() is only read when the opcode is kFusion.
bool HloInstruction::IsCustomFusion() const {
  if (opcode() != HloOpcode::kFusion) {
    return false;
  }
  return fusion_kind() == FusionKind::kCustom;
}
// Reports whether this instruction may be placed inside a fusion, decided
// per opcode as listed below.
bool HloInstruction::IsFusible() const {
  switch (opcode_) {
    // Some kinds of instructions don't make sense to fuse.
    case HloOpcode::kDomain:
    case HloOpcode::kParameter:
    case HloOpcode::kWhile:
    case HloOpcode::kConditional:
    case HloOpcode::kCall:
      return false;

    // Fusions are always fusible. Map, reduce and reduce-window are accepted
    // without a side-effect check: side effecting reduce and reduce window
    // would be invalid HLO.
    case HloOpcode::kFusion:
    case HloOpcode::kMap:
    case HloOpcode::kReduce:
    case HloOpcode::kReduceWindow:
      return true;

    // An rng is fusible only while it has at most one user.
    case HloOpcode::kRng: {
      const bool at_most_one_user = user_count() <= 1;
      return at_most_one_user;
    }

    // Everything else: side effecting instructions cannot be fused.
    default: {
      const bool side_effecting = HasSideEffect();
      return !side_effecting;
    }
  }
}
// Constructs an instruction with the given opcode and result shape.
//
// The instruction starts with unique_id_ == -1 (ToProto() treats -1 as "no
// valid id"), its name set to the opcode string, and marked_as_dead_ false.
// The shape is validated with ValidateShapeWithOptionalLayout() through
// TF_DCHECK_OK.
HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape)
: unique_id_(-1),
opcode_(opcode),
shape_(shape),
name_(HloOpcodeString(opcode)),
marked_as_dead_(false) {
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_));
}
// Dispatches this instruction to the visitor method that corresponds to its
// opcode and returns that method's Status.
//
// Each case forwards `this` to the matching visitor->Handle*() method. Note
// that kRoundNearestAfz is routed to HandleRound(). If no case matches, an
// InternalError naming the opcode is returned.
template <typename HloInstructionPtr>
Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
switch (opcode_) {
case HloOpcode::kAbs:
return visitor->HandleAbs(this);
case HloOpcode::kAtan2:
return visitor->HandleAtan2(this);
case HloOpcode::kRoundNearestAfz:
return visitor->HandleRound(this);
case HloOpcode::kRoundNearestEven:
return visitor->HandleRoundNearestEven(this);
case HloOpcode::kBatchNormTraining:
return visitor->HandleBatchNormTraining(this);
case HloOpcode::kBatchNormInference:
return visitor->HandleBatchNormInference(this);
case HloOpcode::kBatchNormGrad:
return visitor->HandleBatchNormGrad(this);
case HloOpcode::kLogistic:
return visitor->HandleLogistic(this);
case HloOpcode::kSign:
return visitor->HandleSign(this);
case HloOpcode::kConstant:
return visitor->HandleConstant(this);
case HloOpcode::kGetTupleElement:
return visitor->HandleGetTupleElement(this);
case HloOpcode::kParameter:
return visitor->HandleParameter(this);
case HloOpcode::kCompare:
return visitor->HandleCompare(this);
case HloOpcode::kComplex:
return visitor->HandleComplex(this);
case HloOpcode::kAdd:
return visitor->HandleAdd(this);
case HloOpcode::kDivide:
return visitor->HandleDivide(this);
case HloOpcode::kSubtract:
return visitor->HandleSubtract(this);
case HloOpcode::kMaximum:
return visitor->HandleMaximum(this);
case HloOpcode::kMinimum:
return visitor->HandleMinimum(this);
case HloOpcode::kAnd:
return visitor->HandleAnd(this);
case HloOpcode::kOr:
return visitor->HandleOr(this);
case HloOpcode::kXor:
return visitor->HandleXor(this);
case HloOpcode::kShiftLeft:
return visitor->HandleShiftLeft(this);
case HloOpcode::kShiftRightArithmetic:
return visitor->HandleShiftRightArithmetic(this);
case HloOpcode::kShiftRightLogical:
return visitor->HandleShiftRightLogical(this);
case HloOpcode::kConcatenate:
return visitor->HandleConcatenate(this);
case HloOpcode::kConvert:
return visitor->HandleConvert(this);
case HloOpcode::kBitcastConvert:
return visitor->HandleBitcastConvert(this);
case HloOpcode::kCopy:
return visitor->HandleCopy(this);
case HloOpcode::kMultiply:
return visitor->HandleMultiply(this);
case HloOpcode::kDot:
return visitor->HandleDot(this);
case HloOpcode::kPower:
return visitor->HandlePower(this);
case HloOpcode::kRemainder:
return visitor->HandleRemainder(this);
case HloOpcode::kSelect:
return visitor->HandleSelect(this);
case HloOpcode::kConvolution:
return visitor->HandleConvolution(this);
case HloOpcode::kFft:
return visitor->HandleFft(this);
case HloOpcode::kAllGather:
return visitor->HandleAllGather(this);
case HloOpcode::kAllGatherStart:
return visitor->HandleAllGatherStart(this);
case HloOpcode::kAllGatherDone:
return visitor->HandleAllGatherDone(this);
case HloOpcode::kAllReduce:
return visitor->HandleAllReduce(this);
case HloOpcode::kReduceScatter:
return visitor->HandleReduceScatter(this);
case HloOpcode::kAllReduceStart:
return visitor->HandleAllReduceStart(this);
case HloOpcode::kAllReduceDone:
return visitor->HandleAllReduceDone(this);
case HloOpcode::kAllToAll:
return visitor->HandleAllToAll(this);
case HloOpcode::kCollectivePermute:
return visitor->HandleCollectivePermute(this);
case HloOpcode::kCollectivePermuteStart:
return visitor->HandleCollectivePermuteStart(this);
case HloOpcode::kCollectivePermuteDone:
return visitor->HandleCollectivePermuteDone(this);
case HloOpcode::kReplicaId:
return visitor->HandleReplicaId(this);
case HloOpcode::kPartitionId:
return visitor->HandlePartitionId(this);
case HloOpcode::kTuple:
return visitor->HandleTuple(this);
case HloOpcode::kMap:
return visitor->HandleMap(this);
case HloOpcode::kClamp:
return visitor->HandleClamp(this);
case HloOpcode::kReduce:
return visitor->HandleReduce(this);
case HloOpcode::kReduceWindow:
return visitor->HandleReduceWindow(this);
case HloOpcode::kSelectAndScatter:
return visitor->HandleSelectAndScatter(this);
case HloOpcode::kNegate:
return visitor->HandleNegate(this);
case HloOpcode::kExp:
return visitor->HandleExp(this);
case HloOpcode::kExpm1:
return visitor->HandleExpm1(this);
case HloOpcode::kFloor:
return visitor->HandleFloor(this);
case HloOpcode::kCeil:
return visitor->HandleCeil(this);
case HloOpcode::kClz:
return visitor->HandleClz(this);
case HloOpcode::kLog:
return visitor->HandleLog(this);
case HloOpcode::kLog1p:
return visitor->HandleLog1p(this);
case HloOpcode::kTanh:
return visitor->HandleTanh(this);
case HloOpcode::kCos:
return visitor->HandleCos(this);
case HloOpcode::kSin:
return visitor->HandleSin(this);
case HloOpcode::kSqrt:
return visitor->HandleSqrt(this);
case HloOpcode::kCbrt:
return visitor->HandleCbrt(this);
case HloOpcode::kRsqrt:
return visitor->HandleRsqrt(this);
case HloOpcode::kReal:
return visitor->HandleReal(this);
case HloOpcode::kImag:
return visitor->HandleImag(this);
case HloOpcode::kIsFinite:
return visitor->HandleIsFinite(this);
case HloOpcode::kNot:
return visitor->HandleNot(this);
case HloOpcode::kPopulationCount:
return visitor->HandlePopulationCount(this);
case HloOpcode::kBitcast:
return visitor->HandleBitcast(this);
case HloOpcode::kBroadcast:
return visitor->HandleBroadcast(this);
case HloOpcode::kPad:
return visitor->HandlePad(this);
case HloOpcode::kReshape:
return visitor->HandleReshape(this);
case HloOpcode::kDynamicReshape:
return visitor->HandleDynamicReshape(this);
case HloOpcode::kTranspose:
return visitor->HandleTranspose(this);
case HloOpcode::kReverse:
return visitor->HandleReverse(this);
case HloOpcode::kReducePrecision:
return visitor->HandleReducePrecision(this);
case HloOpcode::kSlice:
return visitor->HandleSlice(this);
case HloOpcode::kDynamicSlice:
return visitor->HandleDynamicSlice(this);
case HloOpcode::kDynamicUpdateSlice:
return visitor->HandleDynamicUpdateSlice(this);
case HloOpcode::kSort:
return visitor->HandleSort(this);
case HloOpcode::kInfeed:
return visitor->HandleInfeed(this);
case HloOpcode::kOutfeed:
return visitor->HandleOutfeed(this);
case HloOpcode::kRng:
return visitor->HandleRng(this);
case HloOpcode::kRngBitGenerator:
return visitor->HandleRngBitGenerator(this);
case HloOpcode::kRngGetAndUpdateState:
return visitor->HandleRngGetAndUpdateState(this);
case HloOpcode::kWhile:
return visitor->HandleWhile(this);
case HloOpcode::kFusion:
return visitor->HandleFusion(this);
case HloOpcode::kCall:
return visitor->HandleCall(this);
case HloOpcode::kConditional:
return visitor->HandleConditional(this);
case HloOpcode::kCustomCall:
return visitor->HandleCustomCall(this);
case HloOpcode::kAsyncStart:
return visitor->HandleAsyncStart(this);
case HloOpcode::kAsyncUpdate:
return visitor->HandleAsyncUpdate(this);
case HloOpcode::kAsyncDone:
return visitor->HandleAsyncDone(this);
case HloOpcode::kCopyStart:
return visitor->HandleCopyStart(this);
case HloOpcode::kCopyDone:
return visitor->HandleCopyDone(this);
case HloOpcode::kRecv:
return visitor->HandleRecv(this);
case HloOpcode::kRecvDone:
return visitor->HandleRecvDone(this);
case HloOpcode::kSend:
return visitor->HandleSend(this);
case HloOpcode::kSendDone:
return visitor->HandleSendDone(this);
case HloOpcode::kGather:
return visitor->HandleGather(this);
case HloOpcode::kScatter:
return visitor->HandleScatter(this);
case HloOpcode::kDomain:
return visitor->HandleDomain(this);
case HloOpcode::kAfterAll:
return visitor->HandleAfterAll(this);
case HloOpcode::kAddDependency:
return visitor->HandleAddDependency(this);
case HloOpcode::kIota:
return visitor->HandleIota(this);
case HloOpcode::kGetDimensionSize:
return visitor->HandleGetDimensionSize(this);
case HloOpcode::kSetDimensionSize:
return visitor->HandleSetDimensionSize(this);
case HloOpcode::kTriangularSolve:
return visitor->HandleTriangularSolve(this);
case HloOpcode::kCholesky:
return visitor->HandleCholesky(this);
case HloOpcode::kOptimizationBarrier:
return visitor->HandleOptimizationBarrier(this);
}
// Reached only when opcode_ matched none of the cases above.
return InternalError(
"Unhandled HloOpcode for DfsHloVisitor: %s. This should not happen - "
"please file a bug for XLA.",
HloOpcodeString(opcode_));
}
// Explicit instantiations of HloInstruction::Visit for the two visitor types
// named below, DfsHloVisitor and ConstDfsHloVisitor.
template Status HloInstruction::Visit(DfsHloVisitor* visitor);
template Status HloInstruction::Visit(ConstDfsHloVisitor* visitor);
// Push "child" onto the dfs_stack if not already visited. Returns false if a
// cycle was detected, and true otherwise.
template <typename Visitor>
inline bool PushDFSChild(Visitor* visitor, DFSStack* dfs_stack,
HloInstruction* child) {
CHECK(child != nullptr);
const int id = child->unique_id();
CHECK_GE(id, 0) << "instruction may not have a parent computation";
switch (visitor->GetVisitState(id)) {
case Visitor::kVisiting:
return false;
case Visitor::kVisited:
// Nothing to do
return true;
case Visitor::kNotVisited:
dfs_stack->push_back(std::make_pair(id, child));
return true;
}
}
// Comparator over <unique id, instruction> pairs, i.e. the entries stored on
// the DFS stack. PostOrderDFS passes it to std::sort to order the children it
// has just pushed.
using InternalCompareFunction =
std::function<bool(std::pair<int, const HloInstruction*>,
std::pair<int, const HloInstruction*>)>;
// Iterative post-order DFS starting at `root`.
//
// Each instruction reachable through operands (and, unless
// `ignore_control_predecessors` is true, through control predecessors) is
// handed to the visitor via Preprocess / Visit / Postprocess after the
// children pushed for it have been processed. If `operand_order` is non-null,
// the children pushed for a node are sorted with it before being processed.
//
// Returns the first error produced by the visitor callbacks, or
// FailedPrecondition if a cycle is detected; OkStatus() otherwise.
//
// NOTE(review): root->parent() is dereferenced on the first line, before any
// check — presumably callers guarantee `root` belongs to a computation;
// confirm.
template <typename Visitor>
static Status PostOrderDFS(HloInstruction* root, Visitor* visitor,
const InternalCompareFunction* operand_order,
bool ignore_control_predecessors) {
visitor->ReserveVisitStates(root->parent()->instruction_count());
// dfs_stack holds pairs of <HloInstruction*->unique_id(), HloInstruction*>.
//
// We need to keep track of both the id and the instruction because
// instructions can get deleted while they are on the stack, so we
// can't always use the (potentially dead) instruction object to grab
// its id.
DFSStack dfs_stack;
dfs_stack.emplace_back(root->unique_id(), root);
do {
DCHECK(!dfs_stack.empty());
int current_id = dfs_stack.back().first;
HloInstruction* current_node = dfs_stack.back().second;
CHECK_GE(current_id, 0) << current_id << ": " << current_node
<< ": instruction may not have parent computation";
typename Visitor::VisitState visit_state =
visitor->GetVisitState(current_id);
// Already fully processed (e.g. pushed more than once): just drop the entry.
if (visit_state == Visitor::kVisited) {
dfs_stack.pop_back();
VLOG(3) << "Not visiting HLO (id = " << current_id
<< ") as it was already visited.";
continue;
}
// Second encounter of this entry at the top of the stack: everything pushed
// above it has been popped, so run the visitor callbacks now and mark the
// node as visited.
if (visit_state == Visitor::kVisiting) {
dfs_stack.pop_back();
TF_RETURN_IF_ERROR(visitor->Preprocess(current_node));
VLOG(2) << "Visiting HLO %" << current_node->name();
TF_RETURN_IF_ERROR(current_node->Visit(visitor));
visitor->SetVisitState(current_id, Visitor::kVisited);
TF_RETURN_IF_ERROR(visitor->Postprocess(current_node));
continue;
}
// First encounter: mark as in-progress, leave the entry on the stack, and
// push its children on top of it.
visitor->SetVisitState(current_id, Visitor::kVisiting);
// Remember where this node's children start so that only they are sorted
// and reversed below.
const size_t old_dfs_stack_size = dfs_stack.size();
for (HloInstruction* child : current_node->operands()) {
if (!ABSL_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
PrintCycle(child, &dfs_stack);
return FailedPrecondition(
"A cycle is detected while visiting instruction %s",
current_node->ToString());
}
}
if (!ignore_control_predecessors) {
for (HloInstruction* child : current_node->control_predecessors()) {
if (!ABSL_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
PrintCycle(child, &dfs_stack);
return FailedPrecondition(
"A cycle is detected while visiting instruction %s",
current_node->ToString());
}
}
}
if (operand_order != nullptr) {
std::sort(dfs_stack.begin() + old_dfs_stack_size, dfs_stack.end(),
*operand_order);
}
// This makes the traversal order the same as what you'd expect
// out of a recursive algorithm.
// (The stack is consumed from the back, so reversing puts the first child
// on top.)
std::reverse(dfs_stack.begin() + old_dfs_stack_size, dfs_stack.end());
} while (!dfs_stack.empty());
return OkStatus();
}
// Runs a post-order DFS rooted at this instruction, dispatching every visited
// instruction to `visitor`. Operands are not reordered. When
// `call_finish_visit` is true, visitor->FinishVisit(this) is invoked after a
// successful traversal. Any error from the traversal or from FinishVisit is
// returned to the caller.
template <typename HloInstructionPtr>
Status HloInstruction::Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor,
                              bool call_finish_visit,
                              bool ignore_control_predecessors) {
  VLOG(3) << "HloInstruction::Accept(%" << name() << ")";

  // A null comparator means children are traversed in their natural order.
  const InternalCompareFunction* const no_operand_order = nullptr;
  TF_RETURN_IF_ERROR(PostOrderDFS(this, visitor, no_operand_order,
                                  ignore_control_predecessors));

  if (!call_finish_visit) {
    return OkStatus();
  }
  TF_RETURN_IF_ERROR(visitor->FinishVisit(this));
  return OkStatus();
}
// Explicit instantiations of the HloInstruction::Accept template for the
// DfsHloVisitor* and ConstDfsHloVisitor* argument types.
template Status HloInstruction::Accept(DfsHloVisitor*, bool, bool);
template Status HloInstruction::Accept(ConstDfsHloVisitor*, bool, bool);
// Like Accept(), but the children pushed for each instruction are sorted with
// `operand_order` before being traversed. Control predecessors are always
// followed. When `call_finish_visit` is true, visitor->FinishVisit(this) is
// called after a successful traversal.
Status HloInstruction::AcceptWithOperandOrder(
    DfsHloVisitor* visitor, const CompareFunction& operand_order,
    bool call_finish_visit) {
  VLOG(2) << "HloInstruction::AcceptWithOperandOrder(%" << name() << ")";

  // The DFS stack stores <id, instruction> pairs; adapt the client's
  // comparator, which only understands instructions, by dropping the ids.
  InternalCompareFunction stack_entry_order =
      [&operand_order](std::pair<int, const HloInstruction*> lhs,
                       std::pair<int, const HloInstruction*> rhs) {
        return operand_order(lhs.second, rhs.second);
      };

  TF_RETURN_IF_ERROR(PostOrderDFS(this, visitor, &stack_entry_order,
                                  /*ignore_control_predecessors=*/false));

  if (call_finish_visit) {
    VLOG(3) << "HloInstruction::AcceptWithOperandOrder BEFORE FINISH VISIT";
    TF_RETURN_IF_ERROR(visitor->FinishVisit(this));
    VLOG(3) << "HloInstruction::AcceptWithOperandOrder AFTER FINISH VISIT";
  }

  VLOG(2) << "HloInstruction::AcceptWithOperandOrder EXIT";
  return OkStatus();
}
// Returns a reference to the shape stored on this instruction.
const Shape& HloInstruction::shape() const {
  return shape_;
}
// Returns, in increasing order, every operand position at which `operand`
// appears. The result is empty if `operand` is not an operand of this
// instruction.
absl::InlinedVector<int64_t, 4> HloInstruction::OperandIndices(
    const HloInstruction* operand) const {
  absl::InlinedVector<int64_t, 4> positions;
  for (int64_t index = 0; index < operand_count(); ++index) {
    if (this->operand(index) != operand) {
      continue;
    }
    positions.push_back(index);
  }
  return positions;
}
// True iff this instruction is elementwise and has exactly two operands.
bool HloInstruction::IsElementwiseBinary() const {
  if (!IsElementwise()) {
    return false;
  }
  return operand_count() == 2;
}
// Asks IsElementwiseImpl about the instruction as a whole, i.e. without
// naming a specific operand index.
bool HloInstruction::IsElementwise() const {
  const bool elementwise = IsElementwiseImpl(std::nullopt);
  return elementwise;
}
bool HloInstruction::IsElementwiseOnOperand(int64_t operand_idx) const {
return IsElementwiseImpl(operand_idx);
}
namespace {
// Indicates how an instruction uses a value (such as an operand).
//
// Does it (a) not use it, (b) use it, or (c) use it multiple times?
enum class UseKind { kReuse = 0, kUse = 1, kNoUse = 2 };
// A helper class for memoized, recursive computation of HloOpcode::kFusion
// in HloInstruction::OperandElementUse below.
class FusionReusesParamElements {
public:
static UseKind Compute(int64_t i, const HloInstruction& hlo) {
absl::flat_hash_map<const HloInstruction*, UseKind> memoization_cache;
return ComputeInternal(i, hlo, &memoization_cache);
}
private:
static UseKind ComputeInternal(
int64_t outer_param_num, const HloInstruction& hlo,
absl::flat_hash_map<const HloInstruction*, UseKind>* cache);
};
} // namespace
// Classifies how `instr` uses the elements of its operand number
// `operand_num`: UseKind::kUse or UseKind::kReuse, decided per opcode.
static UseKind OperandElementUse(const HloInstruction& instr,
                                 int64_t operand_num) {
  switch (instr.opcode()) {
    case HloOpcode::kBitcast:
    case HloOpcode::kConcatenate:
    case HloOpcode::kReshape:
    case HloOpcode::kReverse:
    case HloOpcode::kSlice:
    case HloOpcode::kTranspose:
    case HloOpcode::kGather:
      return UseKind::kUse;

    case HloOpcode::kPad: {
      // Only the padding value (any operand after the first) is reused; the
      // padded array's elements are not.
      if (operand_num > 0) {
        return UseKind::kReuse;
      }
      return UseKind::kUse;
    }

    case HloOpcode::kReduce: {
      // Operands at or past input_count() are the init values, which are
      // reused; the reduced arrays are not.
      const auto input_count =
          Cast<HloReduceInstruction>(&instr)->input_count();
      if (operand_num >= input_count) {
        return UseKind::kReuse;
      }
      return UseKind::kUse;
    }

    case HloOpcode::kFusion:
      // Handled by the memoizing, recursive helper class.
      return FusionReusesParamElements::Compute(
          operand_num, *instr.fused_expression_root());

    case HloOpcode::kDot: {
      // Matrix-vector dots do not reuse the matrix operand: when the result
      // has at most one dimension and the *other* operand has rank <= 1, the
      // queried operand is only used.
      const bool small_result = instr.shape().dimensions_size() <= 1;
      const bool is_lhs_or_rhs = operand_num == 0 || operand_num == 1;
      if (small_result && is_lhs_or_rhs) {
        const HloInstruction* other_operand = instr.operand(1 - operand_num);
        if (other_operand->shape().rank() <= 1) {
          return UseKind::kUse;
        }
      }
      return UseKind::kReuse;
    }

    case HloOpcode::kDynamicUpdateSlice: {
      // Operands 0 and 1 are used once; every later operand (the start
      // indices) is reused.
      const bool is_first_or_second = operand_num == 0 || operand_num == 1;
      return is_first_or_second ? UseKind::kUse : UseKind::kReuse;
    }

    default:
      if (instr.IsElementwise()) {
        return UseKind::kUse;
      }
      return UseKind::kReuse;
  }
}
// Computes how `hlo` (an instruction inside a fusion) uses the elements of the
// fusion parameter numbered `outer_param_num`, memoizing results per
// instruction in `cache`.
UseKind FusionReusesParamElements::ComputeInternal(
    int64_t outer_param_num, const HloInstruction& hlo,
    absl::flat_hash_map<const HloInstruction*, UseKind>* cache) {
  // The parameter of interest trivially "uses" itself.
  const auto* param = DynCast<HloParameterInstruction>(&hlo);
  if (param != nullptr && param->parameter_number() == outer_param_num) {
    return UseKind::kUse;
  }

  // Seed the cache with kNoUse; an existing entry is the memoized answer.
  const auto insertion = cache->emplace(&hlo, UseKind::kNoUse);
  if (!insertion.second) {
    return insertion.first->second;
  }

  // Our dataflow graph has no loops, so we don't need the fixed point
  // computation.
  UseKind combined = UseKind::kNoUse;
  const int64_t num_operands = hlo.operands().size();
  for (int64_t index = 0; index < num_operands; ++index) {
    // How does the HLO use this operand? Operands that are not used, or that
    // themselves do not use the outer parameter, leave the value unchanged.
    const UseKind hlo_use = OperandElementUse(hlo, index);
    if (hlo_use != UseKind::kNoUse) {
      const UseKind operand_use =
          ComputeInternal(outer_param_num, *hlo.operand(index), cache);
      if (operand_use != UseKind::kNoUse) {
        // Meet operator on a lattice:
        //
        //   kReuse < kUse < kNoUse.
        combined = std::min({combined, hlo_use, operand_use});
      }
    }

    // The recursive call may have inserted into the cache, so look the entry
    // up again rather than reusing an iterator obtained earlier.
    cache->find(&hlo)->second = combined;

    // kReuse is the minimum of the lattice; the remaining operands cannot
    // change the result.
    if (combined == UseKind::kReuse) {
      break;
    }
  }
  return combined;
}
bool HloInstruction::ReusesOperandElements(int64_t i) const {
return OperandElementUse(*this, i) == UseKind::kReuse;
}
std::optional<ShapeUtil::ShapeEqualityDescriptor>
HloInstruction::ReshapeMerelyInsertsOrDeletes1SizedDimensions() const {
if (HloOpcode::kReshape != opcode_) {
return std::nullopt;
}
return ShapeUtil::InsertedOrDeleted1SizedDimensions(operand(0)->shape_,
shape_);
}
// Returns the textual name of a fusion kind ("kLoop", "kInput", "kOutput" or
// "kCustom"). Dies with a fatal log if `kind` holds a value that is not one
// of the enumerators handled below.
std::string ToString(HloInstruction::FusionKind kind) {
  switch (kind) {
    case HloInstruction::FusionKind::kLoop:
      return "kLoop";
    case HloInstruction::FusionKind::kInput:
      return "kInput";
    case HloInstruction::FusionKind::kOutput:
      return "kOutput";
    case HloInstruction::FusionKind::kCustom:
      return "kCustom";
  }
  // Flowing off the end of a non-void function is undefined behavior, so make
  // the out-of-range case an explicit fatal error. The value is printed as an
  // integer because operator<< for FusionKind calls back into this function.
  LOG(FATAL) << "Unknown fusion kind: " << static_cast<int>(kind);
}
// Parses the textual name of a fusion kind ("kLoop", "kInput", "kOutput",
// "kCustom"; exact, case-sensitive match). Returns InvalidArgument for any
// other string.
StatusOr<HloInstruction::FusionKind> StringToFusionKind(
    const std::string& kind_name) {
  struct NamedKind {
    const char* name;
    HloInstruction::FusionKind kind;
  };
  static constexpr NamedKind kKnownKinds[] = {
      {"kLoop", HloInstruction::FusionKind::kLoop},
      {"kInput", HloInstruction::FusionKind::kInput},
      {"kOutput", HloInstruction::FusionKind::kOutput},
      {"kCustom", HloInstruction::FusionKind::kCustom},
  };
  for (const NamedKind& candidate : kKnownKinds) {
    if (kind_name == candidate.name) {
      return candidate.kind;
    }
  }
  return InvalidArgument("Unknown fusion kind: %s", kind_name);
}
// Renders frontend attributes as a brace-enclosed, comma-separated list of
// attribute="value" pairs sorted by (name, value), e.g.
// {name="value_a",type="int32_t"}.
std::string FrontendAttributesToString(
    const FrontendAttributes& frontend_attributes) {
  using Attribute = std::pair<std::string, std::string>;

  // Copy the map entries into a vector so they can be emitted in sorted order.
  const auto& attribute_map = frontend_attributes.map();
  std::vector<Attribute> ordered(attribute_map.begin(), attribute_map.end());
  absl::c_sort(ordered);

  const std::string body = absl::StrJoin(
      ordered, ",", [](std::string* out, const Attribute& attribute) {
        absl::StrAppend(out, attribute.first, "=\"", attribute.second, "\"");
      });
  return absl::StrFormat("{%s}", body);
}
std::string PaddingConfigToString(const PaddingConfig& padding) {
bool has_interior_padding =
absl::c_any_of(padding.dimensions(),
[](const PaddingConfig::PaddingConfigDimension& dim) {
return dim.interior_padding() != 0;
});
return StrJoin(
padding.dimensions(), "x",
[&](std::string* out, const PaddingConfig::PaddingConfigDimension& dim) {
StrAppend(
out, dim.edge_padding_low(), "_", dim.edge_padding_high(),
has_interior_padding ? StrCat("_", dim.interior_padding()) : "");
});
}
// Returns the enum name of `distribution`, converted to ASCII lower case.
std::string RandomDistributionToString(const RandomDistribution& distribution) {
  const std::string& enum_name = RandomDistribution_Name(distribution);
  return absl::AsciiStrToLower(enum_name);
}
// Returns the enum name of `algorithm`, converted to ASCII lower case.
std::string RandomAlgorithmToString(const RandomAlgorithm& algorithm) {
  const std::string& enum_name = RandomAlgorithm_Name(algorithm);
  return absl::AsciiStrToLower(enum_name);
}
// Returns the enum name of `precision`, converted to ASCII lower case.
std::string PrecisionToString(const PrecisionConfig::Precision& precision) {
  const std::string& enum_name = PrecisionConfig::Precision_Name(precision);
  return absl::AsciiStrToLower(enum_name);
}
// Returns the enum name of `schedule`, converted to ASCII lower case.
static std::string CustomCallScheduleToString(
    const CustomCallSchedule& schedule) {
  const std::string& enum_name = CustomCallSchedule_Name(schedule);
  return absl::AsciiStrToLower(enum_name);
}
// Returns the enum name of the given custom-call API version, converted to
// ASCII lower case.
static std::string CustomCallApiVersionToString(
    const CustomCallApiVersion& schedule) {
  const std::string& enum_name = CustomCallApiVersion_Name(schedule);
  return absl::AsciiStrToLower(enum_name);
}
// Renders dot dimension numbers as a ", "-separated list of
// name={d0,d1,...} fields. The contracting-dimension fields are always
// printed; the batch-dimension fields only when non-empty.
std::string DotDimensionNumbersToString(const DotDimensionNumbers& dnums) {
  // Formats one field as label={comma-separated dims}.
  const auto format_field = [](const char* label, const auto& dims) {
    return StrCat(label, "={", StrJoin(dims, ","), "}");
  };

  std::vector<std::string> fields;
  if (!dnums.lhs_batch_dimensions().empty()) {
    fields.push_back(
        format_field("lhs_batch_dims", dnums.lhs_batch_dimensions()));
  }
  fields.push_back(format_field("lhs_contracting_dims",
                                dnums.lhs_contracting_dimensions()));
  if (!dnums.rhs_batch_dimensions().empty()) {
    fields.push_back(
        format_field("rhs_batch_dims", dnums.rhs_batch_dimensions()));
  }
  fields.push_back(format_field("rhs_contracting_dims",
                                dnums.rhs_contracting_dimensions()));
  return StrJoin(fields, ", ");
}
std::string ConvolutionDimensionNumbersToString(
const ConvolutionDimensionNumbers& dnums) {
auto len_required = [](int64_t a, int64_t b, absl::Span<const int64_t> cs) {
return std::max({a, b, cs.empty() ? 0 : *absl::c_max_element(cs)}) + 1;
};
// lhs_dims[i] is the symbol of the logical dimension i for the lhs
// operand. E.g. if batch has dimension number 2, then lhs_dims[2] == "b".
std::vector<std::string> lhs_dims(
len_required(dnums.input_batch_dimension(),
dnums.input_feature_dimension(),
dnums.input_spatial_dimensions()),
"?");
lhs_dims[dnums.input_batch_dimension()] = 'b';
lhs_dims[dnums.input_feature_dimension()] = 'f';
for (int64_t i = 0; i < dnums.input_spatial_dimensions().size(); ++i) {
lhs_dims[dnums.input_spatial_dimensions(i)] = StrCat(i);
}
std::vector<std::string> rhs_dims(
len_required(dnums.kernel_input_feature_dimension(),
dnums.kernel_output_feature_dimension(),
dnums.kernel_spatial_dimensions()),
"?");
rhs_dims[dnums.kernel_input_feature_dimension()] = "i";
rhs_dims[dnums.kernel_output_feature_dimension()] = "o";
for (int64_t i = 0; i < dnums.kernel_spatial_dimensions().size(); ++i) {
rhs_dims[dnums.kernel_spatial_dimensions(i)] = StrCat(i);
}
std::vector<std::string> output_dims(
len_required(dnums.output_batch_dimension(),
dnums.output_feature_dimension(),
dnums.output_spatial_dimensions()),
"?");
output_dims[dnums.output_batch_dimension()] = 'b';
output_dims[dnums.output_feature_dimension()] = 'f';
for (int64_t i = 0; i < dnums.output_spatial_dimensions().size(); ++i) {
output_dims[dnums.output_spatial_dimensions(i)] = StrCat(i);
}
return StrCat(StrJoin(lhs_dims, ""), "_", StrJoin(rhs_dims, ""), "->",
StrJoin(output_dims, ""));
}
// Renders replica groups as nested brace lists, e.g. {{0,1},{2,3}}.
std::string ReplicaGroupsToString(
    absl::Span<const ReplicaGroup> replica_groups) {
  // Appends one group as {id,id,...}.
  const auto append_group = [](std::string* out, const ReplicaGroup& group) {
    StrAppend(out, "{", StrJoin(group.replica_ids(), ","), "}");
  };
  return StrCat("{", StrJoin(replica_groups, ",", append_group), "}");
}
// Parses a random-algorithm name case-insensitively (the input is lower-cased
// before lookup). Returns InvalidArgument if the name is not recognized.
StatusOr<RandomAlgorithm> StringToRandomAlgorithm(const std::string& name) {
  using NameTable = absl::flat_hash_map<std::string, RandomAlgorithm>;

  // Built on first use from every valid enum number; heap-allocated and never
  // freed by this function.
  static const NameTable* const kAlgorithmsByName = [] {
    auto* table = new NameTable;
    for (int number = 0; number < RandomAlgorithm_ARRAYSIZE; ++number) {
      if (!RandomAlgorithm_IsValid(number)) {
        continue;
      }
      const auto algorithm = static_cast<RandomAlgorithm>(number);
      (*table)[RandomAlgorithmToString(algorithm)] = algorithm;
    }
    return table;
  }();

  const auto match = kAlgorithmsByName->find(absl::AsciiStrToLower(name));
  if (match != kAlgorithmsByName->end()) {
    return match->second;
  }
  return InvalidArgument("Unknown algorithm");
}
// Parses a random-distribution name case-insensitively (the input is
// lower-cased before lookup). Returns InvalidArgument if the name is not
// recognized.
StatusOr<RandomDistribution> StringToRandomDistribution(
    const std::string& name) {
  using NameTable = absl::flat_hash_map<std::string, RandomDistribution>;

  // Built on first use from every valid enum number; heap-allocated and never
  // freed by this function.
  static const NameTable* const kDistributionsByName = [] {
    auto* table = new NameTable;
    for (int number = 0; number < RandomDistribution_ARRAYSIZE; ++number) {
      if (!RandomDistribution_IsValid(number)) {
        continue;
      }
      const auto distribution = static_cast<RandomDistribution>(number);
      (*table)[RandomDistributionToString(distribution)] = distribution;
    }
    return table;
  }();

  const auto match = kDistributionsByName->find(absl::AsciiStrToLower(name));
  if (match != kDistributionsByName->end()) {
    return match->second;
  }
  return InvalidArgument("Unknown distribution");
}
// Parses a precision name case-insensitively (the input is lower-cased before
// lookup) into a PrecisionConfig::Precision. Returns InvalidArgument if the
// name is not recognized.
StatusOr<PrecisionConfig::Precision> StringToPrecision(
    const std::string& name) {
  // Lookup table from lower-case name to enum value. It is built on first use
  // from every valid enum number, heap-allocated, and never freed by this
  // function.
  static absl::flat_hash_map<std::string, PrecisionConfig::Precision>* map =
      [] {
        static auto* map =
            new absl::flat_hash_map<std::string, PrecisionConfig::Precision>;
        for (int i = 0; i < PrecisionConfig::Precision_ARRAYSIZE; i++) {
          if (PrecisionConfig::Precision_IsValid(i)) {
            auto value = static_cast<PrecisionConfig::Precision>(i);
            (*map)[PrecisionToString(value)] = value;
          }
        }
        return map;
      }();
  auto found = map->find(absl::AsciiStrToLower(name));
  if (found == map->end()) {
    // This function parses precisions; the message previously said
    // "distribution", copied from StringToRandomDistribution.
    return InvalidArgument("Unknown precision");
  }
  return found->second;
}
// Parses a custom-call schedule name case-insensitively (the input is
// lower-cased before lookup). Returns InvalidArgument if the name is not
// recognized.
StatusOr<CustomCallSchedule> StringToCustomCallSchedule(
    absl::string_view name) {
  using NameTable = absl::flat_hash_map<std::string, CustomCallSchedule>;

  // Built on first use from every valid enum number; heap-allocated and never
  // freed by this function.
  static const NameTable* const kSchedulesByName = [] {
    auto* table = new NameTable;
    for (int number = 0; number < CustomCallSchedule_ARRAYSIZE; ++number) {
      if (!CustomCallSchedule_IsValid(number)) {
        continue;
      }
      const auto schedule = static_cast<CustomCallSchedule>(number);
      (*table)[CustomCallScheduleToString(schedule)] = schedule;
    }
    return table;
  }();

  const auto match = kSchedulesByName->find(absl::AsciiStrToLower(name));
  if (match != kSchedulesByName->end()) {
    return match->second;
  }
  return InvalidArgument("Unknown schedule");
}
// Parses a custom-call API version name case-insensitively (the input is
// lower-cased before lookup). Returns InvalidArgument if the name is not
// recognized.
StatusOr<CustomCallApiVersion> StringToCustomCallApiVersion(
    absl::string_view name) {
  using NameTable = absl::flat_hash_map<std::string, CustomCallApiVersion>;

  // Built on first use from every valid enum number; heap-allocated and never
  // freed by this function.
  static const NameTable* const kVersionsByName = [] {
    auto* table = new NameTable;
    for (int number = 0; number < CustomCallApiVersion_ARRAYSIZE; ++number) {
      if (!CustomCallApiVersion_IsValid(number)) {
        continue;
      }
      const auto version = static_cast<CustomCallApiVersion>(number);
      (*table)[CustomCallApiVersionToString(version)] = version;
    }
    return table;
  }();

  const auto match = kVersionsByName->find(absl::AsciiStrToLower(name));
  if (match != kVersionsByName->end()) {
    return match->second;
  }
  return InvalidArgument("Unknown API version");
}
// Streams the textual name of a fusion kind, as produced by ToString(kind).
std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) {
  os << ToString(kind);
  return os;
}
// Strict ordering over instruction pointers. A null pointer sorts before any
// non-null pointer. Non-null instructions are ordered by the unique id of
// their module (when they have one and the modules differ) and otherwise by
// their own unique id. CHECK-fails if exactly one of the two instructions has
// a module.
bool HloPtrComparator::operator()(const HloInstruction* const& lhs,
                                  const HloInstruction* const& rhs) const {
  // Nothing compares less than nullptr.
  if (rhs == nullptr) {
    return false;
  }
  if (lhs == nullptr) {
    return true;
  }

  auto lhs_module = lhs->GetModule();
  auto rhs_module = rhs->GetModule();
  CHECK((lhs_module == nullptr && rhs_module == nullptr) ||
        (lhs_module != nullptr && rhs_module != nullptr));

  if (lhs_module != nullptr) {
    const auto lhs_module_id = lhs_module->unique_id();
    const auto rhs_module_id = rhs_module->unique_id();
    if (lhs_module_id != rhs_module_id) {
      return lhs_module_id < rhs_module_id;
    }
  }
  return lhs->unique_id() < rhs->unique_id();
}
// Fills `proto` with this instruction's backend config.
//
// `proto` is cleared first. If a proto with the same descriptor is already
// stored, it is copied. Otherwise the raw config string is parsed as JSON into
// `proto` and the parsed result is stored; an empty raw string leaves `proto`
// cleared. Returns the parse error, if any.
Status HloInstruction::GetBackendConfigInternal(
    tensorflow::protobuf::Message* proto) const {
  proto->Clear();

  const auto* stored_proto = backend_config_.GetProtoPtr();
  if (stored_proto != nullptr &&
      stored_proto->GetDescriptor() == proto->GetDescriptor()) {
    proto->CopyFrom(*stored_proto);
    return OkStatus();
  }

  // Empty string does not parse as valid JSON, but it's a valid backend config,
  // corresponding to the empty proto.
  auto& raw_config = raw_backend_config_string();
  if (!raw_config.empty()) {
    TF_RETURN_IF_ERROR(tensorflow::HumanReadableJsonToProto(raw_config, proto));
    backend_config_.SetProto(*proto);
  }
  return OkStatus();
}
// Returns the raw string form of the config. If a proto is stored and the raw
// string is still empty, the string is generated from the proto first (dying
// if that conversion fails) and then kept for later calls.
const std::string& HloInstruction::BackendConfigRep::GetRawString() const {
  const bool needs_generation = proto_ && raw_string_.empty();
  if (!needs_generation) {
    return raw_string_;
  }
  raw_string_ = BackendConfigToRawString(*proto_).ValueOrDie();
  return raw_string_;
}
// Returns a copy of this representation. When a proto is stored, only the
// proto is copied and the copy's raw string is left to be generated lazily on
// access; otherwise the raw string is copied.
HloInstruction::BackendConfigRep HloInstruction::BackendConfigRep::Clone()
    const {
  BackendConfigRep copy;
  auto* stored_proto = GetProtoPtr();
  if (stored_proto == nullptr) {
    copy.raw_string_ = raw_string_;
  } else {
    copy.SetProto(*stored_proto);
  }
  return copy;
}
// Replaces the config with a raw string and drops any stored proto.
HloInstruction::BackendConfigRep& HloInstruction::BackendConfigRep::operator=(
    std::string raw_string) {
  proto_.reset();
  raw_string_ = std::move(raw_string);
  return *this;
}
// Replaces the config with a copy of `proto` and empties the raw string.
HloInstruction::BackendConfigRep& HloInstruction::BackendConfigRep::operator=(
    const tensorflow::protobuf::Message& proto) {
  raw_string_.clear();
  SetProto(proto);
  return *this;
}
// Stores a freshly allocated copy of `proto` (same message type, created via
// proto.New()), replacing any previously stored proto. The raw string is not
// touched.
void HloInstruction::BackendConfigRep::SetProto(
    const tensorflow::protobuf::Message& proto) {
  // Finish the copy before releasing the old message: if `proto` refers to
  // the currently stored message, resetting first would destroy the source of
  // the copy.
  tensorflow::protobuf::Message* copy = proto.New();
  copy->CopyFrom(proto);
  proto_.reset(copy);
}
// Two representations are equal if both hold protos that MessageDifferencer
// considers equal, or otherwise if their raw strings are identical.
bool HloInstruction::BackendConfigRep::operator==(
    const BackendConfigRep& other) const {
  auto* this_proto = GetProtoPtr();
  auto* other_proto = other.GetProtoPtr();
  const bool both_have_protos = this_proto != nullptr && other_proto != nullptr;
  if (!both_have_protos) {
    // TODO(b/225956414): Consider canonicalizing raw string form.
    return GetRawString() == other.GetRawString();
  }
  return ::tensorflow::protobuf::util::MessageDifferencer::Equals(
      *this_proto, *other_proto);
}
// Serializes `proto` to its human-readable JSON form, or returns the
// conversion error.
/* static */ StatusOr<std::string> HloInstruction::BackendConfigToRawString(
    const tensorflow::protobuf::Message& proto) {
  std::string json;
  // ignore_accuracy_loss is set to true because the estimated_cycles field can
  // be INT64_MAX. With ignore_accuracy_loss = false and estimated_cycles =
  // INT64_MAX, JsonFormat returns an error status even though there is no
  // accuracy loss for int64_t.
  TF_RETURN_IF_ERROR(tensorflow::ProtoToHumanReadableJson(
      proto, &json, /*ignore_accuracy_loss=*/true));
  return json;
}
// Returns the precision config of a convolution, dot or custom-call
// instruction. Dies with a fatal log for any other instruction type.
const PrecisionConfig& HloInstruction::precision_config() const {
  const auto* as_convolution = DynCast<HloConvolutionInstruction>(this);
  if (as_convolution != nullptr) {
    return as_convolution->precision_config();
  }
  const auto* as_dot = DynCast<HloDotInstruction>(this);
  if (as_dot != nullptr) {
    return as_dot->precision_config();
  }
  const auto* as_custom_call = DynCast<HloCustomCallInstruction>(this);
  if (as_custom_call != nullptr) {
    return as_custom_call->precision_config();
  }
  LOG(FATAL) << "Unimplemented method.";
}
// Returns the mutable precision config of a convolution or dot instruction.
// Dies with a fatal log for any other instruction type.
PrecisionConfig* HloInstruction::mutable_precision_config() {
  auto* as_convolution = DynCast<HloConvolutionInstruction>(this);
  if (as_convolution != nullptr) {
    return as_convolution->mutable_precision_config();
  }
  auto* as_dot = DynCast<HloDotInstruction>(this);
  if (as_dot != nullptr) {
    return as_dot->mutable_precision_config();
  }
  LOG(FATAL) << "Unimplemented method.";
}
// Returns the parent of this instruction's parent computation, or nullptr if
// the instruction has no parent computation.
HloModule* HloInstruction::GetModule() const {
  if (!parent_) {
    return nullptr;
  }
  return parent_->parent();
}
// Replaces this instruction's name with the unique name that `name_uniquer`
// returns for the current name.
void HloInstruction::UniquifyName(NameUniquer* name_uniquer) {
  // The previous version also built a "parent_str" string that was never
  // read; it has been removed.
  name_ = name_uniquer->GetUniqueName(name_);
}
void HloInstruction::set_outer_dimension_partitions(
const std::vector<int64_t>& outer_dimension_partitions) {
outer_dimension_partitions_ = outer_dimension_partitions;
}
// Reorders this instruction's users, control predecessors and control
// successors with MappedPtrContainerSorter, using the corresponding lists of
// `sorted_instruction` as the reference and `map_fn` as the pointer mapping.
// A failed sort is logged as an error and otherwise ignored. The user index
// map is rebuilt from the users list afterwards.
void HloInstruction::SortInstructionUsersAndControlLists(
    const MappedPtrContainerSorter<HloInstruction>::MapPtrFn& map_fn,
    const HloInstruction& sorted_instruction) {
  using Sorter = MappedPtrContainerSorter<HloInstruction>;

  // Users.
  const auto users_status =
      Sorter::Sort(map_fn, Sorter::IndexAfterMappedElementsFn(),
                   sorted_instruction.users_, users_);
  if (!users_status.ok()) {
    LOG(ERROR) << "Failed to sort instruction users for " << name() << "; "
               << users_status;
  }

  // Rebuild the user -> position map so it matches the current users order.
  user_map_.clear();
  const uint64_t user_count = users_.size();
  for (uint64_t position = 0; position < user_count; ++position) {
    user_map_[users_[position]] = position;
  }

  // Control predecessors.
  const auto predecessors_status =
      Sorter::Sort(map_fn, Sorter::IndexAfterMappedElementsFn(),
                   sorted_instruction.control_predecessors_,
                   control_predecessors_);
  if (!predecessors_status.ok()) {
    LOG(ERROR) << "Failed to sort instruction control predecessors for "
               << name() << "; " << predecessors_status;
  }

  // Control successors.
  const auto successors_status =
      Sorter::Sort(map_fn, Sorter::IndexAfterMappedElementsFn(),
                   sorted_instruction.control_successors_, control_successors_);
  if (!successors_status.ok()) {
    LOG(ERROR) << "Failed to sort instruction control successors for " << name()
               << "; " << successors_status;
  }
}
// TODO(b/80131774): Remove these temporary methods after transition.
// Forwards to HloBatchNormInstruction::feature_index() after casting this
// instruction with Cast<>.
int64_t HloInstruction::feature_index() const {
  const auto* batch_norm = Cast<HloBatchNormInstruction>(this);
  return batch_norm->feature_index();
}
// Forwards to HloBatchNormInstruction::epsilon() after casting this
// instruction with Cast<>.
float HloInstruction::epsilon() const {
  const auto* batch_norm = Cast<HloBatchNormInstruction>(this);
  return batch_norm->epsilon();
}
// Forwards to HloFftInstruction::fft_type() after casting this instruction
// with Cast<>.
FftType HloInstruction::fft_type() const {
  const auto* fft = Cast<HloFftInstruction>(this);
  return fft->fft_type();
}
// Forwards to HloFftInstruction::fft_length() after casting this instruction
// with Cast<>.
const std::vector<int64_t>& HloInstruction::fft_length() const {
  const auto* fft = Cast<HloFftInstruction>(this);
  return fft->fft_length();
}
// Forwards to HloConcatenateInstruction::concatenate_dimension() after casting
// this instruction with Cast<>.
int64_t HloInstruction::concatenate_dimension() const {
  const auto* concatenate = Cast<HloConcatenateInstruction>(this);
  return concatenate->concatenate_dimension();
}
// Returns dimension() of this instruction viewed as a set-dimension-size
// instruction if DynCast succeeds, and otherwise as a get-dimension-size
// instruction (via Cast<>).
int64_t HloInstruction::dimension() const {
  const auto* set_size = DynCast<HloSetDimensionSizeInstruction>(this);
  if (set_size == nullptr) {
    return Cast<HloGetDimensionSizeInstruction>(this)->dimension();
  }
  return set_size->dimension();
}
// Forwards to HloReshapeInstruction::inferred_dimension() after casting this
// instruction with Cast<>.
int64_t HloInstruction::inferred_dimension() const {
  const auto* reshape = Cast<HloReshapeInstruction>(this);
  return reshape->inferred_dimension();
}
// True iff this instruction is an HloTransposeInstruction (per DynCast) whose
// own IsRank2Transpose() returns true.
bool HloInstruction::IsRank2Transpose() const {
  const auto* transpose = DynCast<HloTransposeInstruction>(this);
  if (transpose == nullptr) {
    return false;
  }
  return transpose->IsRank2Transpose();
}
// Forwards to HloSliceInstruction::slice_starts(dimension) after casting this
// instruction with Cast<>.
int64_t HloInstruction::slice_starts(int64_t dimension) const {
  const auto* slice = Cast<HloSliceInstruction>(this);
  return slice->slice_starts(dimension);
}
// Forwards to HloSliceInstruction::slice_starts() after casting this
// instruction with Cast<>.
const std::vector<int64_t>& HloInstruction::slice_starts() const {
  const auto* slice = Cast<HloSliceInstruction>(this);
  return slice->slice_starts();
}
// Forwards to HloSliceInstruction::mutable_slice_starts() after casting this
// instruction with Cast<>.
std::vector<int64_t>* HloInstruction::mutable_slice_starts() {
  auto* slice = Cast<HloSliceInstruction>(this);
  return slice->mutable_slice_starts();
}
// Forwards to HloSliceInstruction::slice_limits(dimension) after casting this
// instruction with Cast<>.
int64_t HloInstruction::slice_limits(int64_t dimension) const {
  const auto* slice = Cast<HloSliceInstruction>(this);
  return slice->slice_limits(dimension);
}
// Forwards to HloSliceInstruction::slice_limits() after casting this
// instruction with Cast<>.
const std::vector<int64_t>& HloInstruction::slice_limits() const {
  const auto* slice = Cast<HloSliceInstruction>(this);
  return slice->slice_limits();
}
// Forwards to HloSliceInstruction::mutable_slice_limits() after casting this
// instruction with Cast<>.
std::vector<int64_t>* HloInstruction::mutable_slice_limits() {
  auto* slice = Cast<HloSliceInstruction>(this);
  return slice->mutable_slice_limits();
}
// Forwards to HloSliceInstruction::slice_strides(dimension) after casting this
// instruction with Cast<>.
int64_t HloInstruction::slice_strides(int64_t dimension) const {
  const auto* slice = Cast<HloSliceInstruction>(this);
  return slice->slice_strides(dimension);
}
// Forwards to HloSliceInstruction::slice_strides() after casting this
// instruction with Cast<>.
const std::vector<int64_t>& HloInstruction::slice_strides() const {
  const auto* slice = Cast<HloSliceInstruction>(this);
  return slice->slice_strides();
}
// Forwards to HloSliceInstruction::mutable_slice_strides() after casting this
// instruction with Cast<>.
std::vector<int64_t>* HloInstruction::mutable_slice_strides() {
  auto* slice = Cast<HloSliceInstruction>(this);
  return slice->mutable_slice_strides();
}
// Forwards to HloConstantInstruction::literal() after casting this instruction
// with Cast<>.
const Literal& HloInstruction::literal() const {
  const auto* constant = Cast<HloConstantInstruction>(this);
  return constant->literal();
}
// True iff DynCast<HloConstantInstruction> succeeds on this instruction.
bool HloInstruction::IsConstant() const {
  const auto* constant = DynCast<HloConstantInstruction>(this);
  return constant != nullptr;
}
// Forwards to HloConstantInstruction::RelayoutConstant() after casting this
// instruction with Cast<>.
void HloInstruction::RelayoutConstant(const Layout& new_layout,
                                      const ShapeIndex& shape_index) {
  auto* constant = Cast<HloConstantInstruction>(this);
  constant->RelayoutConstant(new_layout, shape_index);
}
// Forwards to HloCallableInstruction::AppendInstructionIntoCalledComputation()
// after casting this instruction with Cast<>.
HloInstruction* HloInstruction::AppendInstructionIntoCalledComputation(
    HloInstruction* instruction_to_append, bool add_output) {
  auto* callable = Cast<HloCallableInstruction>(this);
  return callable->AppendInstructionIntoCalledComputation(
      instruction_to_append, add_output);
}
// Forwards to HloFusionInstruction::AddFusionOperand() after casting this
// instruction with Cast<>.
HloInstruction* HloInstruction::AddFusionOperand(HloInstruction* new_operand) {
  auto* fusion = Cast<HloFusionInstruction>(this);
  return fusion->AddFusionOperand(new_operand);
}
// Forwards to HloFusionInstruction::MergeFusionInstruction(). Both this
// instruction and `instruction_to_merge` are cast with
// Cast<HloFusionInstruction>.
void HloInstruction::MergeFusionInstruction(
    HloInstruction* instruction_to_merge) {
  auto* fusion = Cast<HloFusionInstruction>(this);
  auto* fusion_to_merge = Cast<HloFusionInstruction>(instruction_to_merge);
  fusion->MergeFusionInstruction(fusion_to_merge);
}
// Forwards to HloFusionInstruction::MergeFusionInstructionIntoMultiOutput().
// Both this instruction and `instruction_to_merge` are cast with
// Cast<HloFusionInstruction>.
void HloInstruction::MergeFusionInstructionIntoMultiOutput(
    HloInstruction* instruction_to_merge) {
  auto* fusion = Cast<HloFusionInstruction>(this);
  auto* fusion_to_merge = Cast<HloFusionInstruction>(instruction_to_merge);
  fusion->MergeFusionInstructionIntoMultiOutput(fusion_to_merge);
}
// Forwards to HloFusionInstruction::FuseInstruction() after casting this
// instruction with Cast<>.
HloInstruction* HloInstruction::FuseInstruction(
    HloInstruction* instruction_to_fuse) {
  auto* fusion = Cast<HloFusionInstruction>(this);
  return fusion->FuseInstruction(instruction_to_fuse);
}
// Forwards to HloFusionInstruction::FuseInstructionIntoMultiOutput() after
// casting this instruction with Cast<>.
HloInstruction* HloInstruction::FuseInstructionIntoMultiOutput(
    HloInstruction* instruction_to_fuse) {
  auto* fusion = Cast<HloFusionInstruction>(this);
  return fusion->FuseInstructionIntoMultiOutput(instruction_to_fuse);
}
// Forwards to HloFusionInstruction::fused_instructions_computation() after
// casting this instruction with Cast<>.
HloComputation* HloInstruction::fused_instructions_computation() const {
  const auto* fusion = Cast<HloFusionInstruction>(this);
  return fusion->fused_instructions_computation();
}
// Forwards to HloFusionInstruction::fused_expression_root() after casting this
// instruction with Cast<>.
HloInstruction* HloInstruction::fused_expression_root() const {
  const auto* fusion = Cast<HloFusionInstruction>(this);
  return fusion->fused_expression_root();
}
// Const overload: forwards to HloFusionInstruction::fused_instructions() after
// casting this instruction with Cast<>.
const tensorflow::gtl::iterator_range<UnwrappingIterator<
    std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
HloInstruction::fused_instructions() const {
  const auto* fusion = Cast<HloFusionInstruction>(this);
  return fusion->fused_instructions();
}
// Non-const overload: forwards to HloFusionInstruction::fused_instructions()
// after casting this instruction with Cast<>.
const tensorflow::gtl::iterator_range<
    UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
HloInstruction::fused_instructions() {
  auto* fusion = Cast<HloFusionInstruction>(this);
  return fusion->fused_instructions();
}
// Forwards to HloFusionInstruction::fused_instruction_count() after casting
// this instruction with Cast<>.
int64_t HloInstruction::fused_instruction_count() const {
  const auto* fusion = Cast<HloFusionInstruction>(this);
  return fusion->fused_instruction_count();
}
// Forwards to HloFusionInstruction::fused_parameter(parameter_number) after
// casting this instruction with Cast<>.
HloInstruction* HloInstruction::fused_parameter(
    int64_t parameter_number) const {
  const auto* fusion = Cast<HloFusionInstruction>(this);
  return fusion->fused_parameter(parameter_number);
}
// Forwards to HloFusionInstruction::fused_parameters() after casting this
// instruction with Cast<>.
const std::vector<HloInstruction*>& HloInstruction::fused_parameters() const {
  const auto* fusion = Cast<HloFusionInstruction>(this);
  return fusion->fused_parameters();
}
const bool HloInstruction::IsMultiOutputFusion() const {
const HloFusionInstruction* fusion = DynCast<HloFusionInstruction>(this);
return fusion != nullptr && fusion->IsMultiOutputFusion();
}
// Forwards to HloFusionInstruction::fusion_kind() after casting this
// instruction with Cast<>.
HloInstruction::FusionKind HloInstruction::fusion_kind() const {
  const auto* fusion = Cast<HloFusionInstruction>(this);
  return fusion->fusion_kind();
}
// Forwards to HloFusionInstruction::set_fusion_kind() after casting this
// instruction with Cast<>.
void HloInstruction::set_fusion_kind(FusionKind kind) {
  auto* fusion = Cast<HloFusionInstruction>(this);
  fusion->set_fusion_kind(kind);
}
// Forwards to HloRngInstruction::random_distribution() after casting this
// instruction with Cast<>.
RandomDistribution HloInstruction::random_distribution() const {
  const auto* rng = Cast<HloRngInstruction>(this);
  return rng->random_distribution();
}
// Forwards to HloParameterInstruction::parameter_number() after casting this
// instruction with Cast<>.
int64_t HloInstruction::parameter_number() const {
  const auto* parameter = Cast<HloParameterInstruction>(this);
  return parameter->parameter_number();
}
// Span overload: forwards to
// HloParameterInstruction::set_parameter_replicated_at_leaf_buffers() after
// casting this instruction with Cast<>.
void HloInstruction::set_parameter_replicated_at_leaf_buffers(
    absl::Span<const bool> parameter_replicated_at_leaf_buffers) {
  auto* parameter = Cast<HloParameterInstruction>(this);
  parameter->set_parameter_replicated_at_leaf_buffers(
      parameter_replicated_at_leaf_buffers);
}
// std::vector overload: forwards to
// HloParameterInstruction::set_parameter_replicated_at_leaf_buffers() after
// casting this instruction with Cast<>.
void HloInstruction::set_parameter_replicated_at_leaf_buffers(
    const std::vector<bool>& parameter_replicated_at_leaf_buffers) {
  auto* parameter = Cast<HloParameterInstruction>(this);
  parameter->set_parameter_replicated_at_leaf_buffers(
      parameter_replicated_at_leaf_buffers);
}
// Forwards to HloParameterInstruction::parameter_replicated_at_leaf_buffers()
// after casting this instruction with Cast<>.
const std::optional<std::vector<bool>>&
HloInstruction::parameter_replicated_at_leaf_buffers() const {
  const auto* parameter = Cast<HloParameterInstruction>(this);
  return parameter->parameter_replicated_at_leaf_buffers();
}
// Forwards to HloGetTupleElementInstruction::tuple_index() after casting this
// instruction with Cast<>.
int64_t HloInstruction::tuple_index() const {
  const auto* get_tuple_element = Cast<HloGetTupleElementInstruction>(this);
  return get_tuple_element->tuple_index();
}
// Forwards to HloGetTupleElementInstruction::set_tuple_index() after casting
// this instruction with Cast<>.
void HloInstruction::set_tuple_index(int64_t new_tuple_index) {
  auto* get_tuple_element = Cast<HloGetTupleElementInstruction>(this);
  get_tuple_element->set_tuple_index(new_tuple_index);
}
// Forwards to HloReducePrecisionInstruction::exponent_bits() after casting
// this instruction with Cast<>.
int32_t HloInstruction::exponent_bits() const {
  const auto* reduce_precision = Cast<HloReducePrecisionInstruction>(this);
  return reduce_precision->exponent_bits();
}
// Forwards to HloReducePrecisionInstruction::mantissa_bits() after casting
// this instruction with Cast<>.
int32_t HloInstruction::mantissa_bits() const {
  const auto* reduce_precision = Cast<HloReducePrecisionInstruction>(this);
  return reduce_precision->mantissa_bits();
}
// Forwards to HloInfeedInstruction::infeed_config() after casting this
// instruction with Cast<>; the result is returned by value.
std::string HloInstruction::infeed_config() const {
  const auto* infeed = Cast<HloInfeedInstruction>(this);
  return infeed->infeed_config();
}
// Forwards to HloInfeedInstruction::set_infeed_config() after casting this
// instruction with Cast<>.
void HloInstruction::set_infeed_config(const std::string& config) {
  auto* infeed = Cast<HloInfeedInstruction>(this);
  infeed->set_infeed_config(config);
}
// Forwards to HloOutfeedInstruction::outfeed_shape() after casting this
// instruction with Cast<>.
const Shape& HloInstruction::outfeed_shape() const {
  const auto* outfeed = Cast<HloOutfeedInstruction>(this);
  return outfeed->outfeed_shape();
}
// Forwards to HloOutfeedInstruction::mutable_outfeed_shape() after casting
// this instruction with Cast<>.
Shape* HloInstruction::mutable_outfeed_shape() {
  auto* outfeed = Cast<HloOutfeedInstruction>(this);
  return outfeed->mutable_outfeed_shape();
}
// Forwards to HloOutfeedInstruction::outfeed_config() after casting this
// instruction with Cast<>.
const std::string& HloInstruction::outfeed_config() const {
  const auto* outfeed = Cast<HloOutfeedInstruction>(this);
  return outfeed->outfeed_config();
}
// Forwards to HloOutfeedInstruction::set_outfeed_config() after casting this
// instruction with Cast<>.
void HloInstruction::set_outfeed_config(const std::string& config) {
  auto* outfeed = Cast<HloOutfeedInstruction>(this);
  outfeed->set_outfeed_config(config);
}
// Forwards to HloCollectiveInstruction::replica_groups() after casting this
// instruction with Cast<>.
const std::vector<ReplicaGroup>& HloInstruction::replica_groups() const {
  const auto* collective = Cast<HloCollectiveInstruction>(this);
  return collective->replica_groups();
}
// Forwards to HloCollectivePermuteInstruction::source_target_pairs() after
// casting this instruction with Cast<>.
const std::vector<std::pair<int64_t, int64_t>>&
HloInstruction::source_target_pairs() const {
  const auto* collective_permute = Cast<HloCollectivePermuteInstruction>(this);
  return collective_permute->source_target_pairs();
}
std::optional<int64_t> HloInstruction::channel_id() const {
return Cast<HloChannelInstruction>(this)->channel_id();
}
void HloInstruction::set_channel_id(const std::optional<int64_t>& channel_id) {
return Cast<HloChannelInstruction>(this)->set_channel_id(channel_id);
}
// Returns the convolution dimension numbers of a convolution or custom-call
// instruction. Dies with a fatal log for any other instruction type.
const ConvolutionDimensionNumbers&
HloInstruction::convolution_dimension_numbers() const {
  const auto* as_convolution = DynCast<HloConvolutionInstruction>(this);
  if (as_convolution != nullptr) {
    return as_convolution->convolution_dimension_numbers();
  }
  const auto* as_custom_call = DynCast<HloCustomCallInstruction>(this);
  if (as_custom_call != nullptr) {
    return as_custom_call->convolution_dimension_numbers();
  }
  LOG(FATAL) << "Unimplemented method.";
}
// Sets the convolution dimension numbers on a convolution or custom-call
// instruction. Dies with a fatal log for any other instruction type.
void HloInstruction::set_convolution_dimension_numbers(
    const ConvolutionDimensionNumbers& dnums) {
  auto* as_convolution = DynCast<HloConvolutionInstruction>(this);
  if (as_convolution != nullptr) {
    as_convolution->set_convolution_dimension_numbers(dnums);
    return;
  }
  auto* as_custom_call = DynCast<HloCustomCallInstruction>(this);
  if (as_custom_call != nullptr) {
    as_custom_call->set_convolution_dimension_numbers(dnums);
    return;
  }
  LOG(FATAL) << "Unimplemented method.";
}
// Returns the feature group count. Uses the HloConvolutionInstruction value
// when DynCast<> to that type succeeds; otherwise `this` is cast to
// HloCustomCallInstruction with Cast<> and its value is returned.
int64_t HloInstruction::feature_group_count() const {
  const auto* as_convolution = DynCast<HloConvolutionInstruction>(this);
  if (as_convolution == nullptr) {
    return Cast<HloCustomCallInstruction>(this)->feature_group_count();
  }
  return as_convolution->feature_group_count();
}
// Sets the feature group count. Targets the HloConvolutionInstruction when
// DynCast<> to that type succeeds; otherwise `this` is cast to
// HloCustomCallInstruction with Cast<> and the setter is invoked there.
void HloInstruction::set_feature_group_count(int64_t feature_group_count) {
  auto* as_convolution = DynCast<HloConvolutionInstruction>(this);
  if (as_convolution != nullptr) {
    as_convolution->set_feature_group_count(feature_group_count);
  } else {
    auto* as_custom_call = Cast<HloCustomCallInstruction>(this);
    as_custom_call->set_feature_group_count(feature_group_count);
  }
}
// Returns the batch group count. Uses the HloConvolutionInstruction value
// when DynCast<> to that type succeeds; otherwise `this` is cast to
// HloCustomCallInstruction with Cast<> and its value is returned.
int64_t HloInstruction::batch_group_count() const {
  const auto* as_convolution = DynCast<HloConvolutionInstruction>(this);
  if (as_convolution == nullptr) {
    return Cast<HloCustomCallInstruction>(this)->batch_group_count();
  }
  return as_convolution->batch_group_count();
}
// Sets the batch group count. Targets the HloConvolutionInstruction when
// DynCast<> to that type succeeds; otherwise `this` is cast to
// HloCustomCallInstruction with Cast<> and the setter is invoked there.
void HloInstruction::set_batch_group_count(int64_t batch_group_count) {
  auto* as_convolution = DynCast<HloConvolutionInstruction>(this);
  if (as_convolution != nullptr) {
    as_convolution->set_batch_group_count(batch_group_count);
  } else {
    auto* as_custom_call = Cast<HloCustomCallInstruction>(this);
    as_custom_call->set_batch_group_count(batch_group_count);
  }
}
// Forwards to HloSelectAndScatterInstruction::select() after casting `this`
// to that subclass with Cast<>.
HloComputation* HloInstruction::select() const {
  const auto* select_and_scatter = Cast<HloSelectAndScatterInstruction>(this);
  return select_and_scatter->select();
}
// Forwards to HloSelectAndScatterInstruction::scatter() after casting `this`
// to that subclass with Cast<>.
HloComputation* HloInstruction::scatter() const {
  const auto* select_and_scatter = Cast<HloSelectAndScatterInstruction>(this);
  return select_and_scatter->scatter();
}
// Passes `computation` on to HloSelectAndScatterInstruction::set_select() for
// `this` cast to HloSelectAndScatterInstruction.
void HloInstruction::set_select(HloComputation* computation) {
  auto* select_and_scatter = Cast<HloSelectAndScatterInstruction>(this);
  select_and_scatter->set_select(computation);
}
// Passes `computation` on to HloSelectAndScatterInstruction::set_scatter()
// for `this` cast to HloSelectAndScatterInstruction.
void HloInstruction::set_scatter(HloComputation* computation) {
  auto* select_and_scatter = Cast<HloSelectAndScatterInstruction>(this);
  select_and_scatter->set_scatter(computation);
}
// Forwards to HloCustomCallInstruction::custom_call_target() after casting
// `this` to that subclass with Cast<>.
const std::string& HloInstruction::custom_call_target() const {
  const auto* custom_call = Cast<HloCustomCallInstruction>(this);
  return custom_call->custom_call_target();
}
// Passes `target` on to HloCustomCallInstruction::set_custom_call_target()
// for `this` cast to HloCustomCallInstruction.
void HloInstruction::set_custom_call_target(absl::string_view target) {
  auto* custom_call = Cast<HloCustomCallInstruction>(this);
  custom_call->set_custom_call_target(target);
}
// Forwards to HloPadInstruction::padding_config() after casting `this` to
// that subclass with Cast<>.
const PaddingConfig& HloInstruction::padding_config() const {
  const auto* pad = Cast<HloPadInstruction>(this);
  return pad->padding_config();
}
// Forwards to HloCustomCallInstruction::padding_type() after casting `this`
// to that subclass with Cast<>. Note the target is the custom-call subclass,
// not HloPadInstruction.
PaddingType HloInstruction::padding_type() const {
  const auto* custom_call = Cast<HloCustomCallInstruction>(this);
  return custom_call->padding_type();
}
// Forwards to HloPadInstruction::mutable_padding_config() after casting
// `this` to that subclass with Cast<>.
PaddingConfig* HloInstruction::mutable_padding_config() {
  auto* pad = Cast<HloPadInstruction>(this);
  return pad->mutable_padding_config();
}
// Forwards `dimension` to HloDynamicSliceInstruction::slice_sizes() after
// casting `this` to that subclass with Cast<>.
int64_t HloInstruction::slice_sizes(int64_t dimension) const {
  const auto* dynamic_slice = Cast<HloDynamicSliceInstruction>(this);
  return dynamic_slice->slice_sizes(dimension);
}
// Forwards to HloDynamicSliceInstruction::dynamic_slice_sizes() after casting
// `this` to that subclass with Cast<>.
const std::vector<int64_t>& HloInstruction::dynamic_slice_sizes() const {
  const auto* dynamic_slice = Cast<HloDynamicSliceInstruction>(this);
  return dynamic_slice->dynamic_slice_sizes();
}
// Forwards to HloCollectivePermuteInstruction::dynamic_slice_sizes_list()
// after casting `this` to that subclass with Cast<>.
const std::vector<std::vector<int64_t>>&
HloInstruction::dynamic_slice_sizes_list() const {
  const auto* collective_permute = Cast<HloCollectivePermuteInstruction>(this);
  return collective_permute->dynamic_slice_sizes_list();
}
// Forwards to HloGatherInstruction::gather_dimension_numbers() after casting
// `this` to that subclass with Cast<>.
const GatherDimensionNumbers& HloInstruction::gather_dimension_numbers() const {
  const auto* gather = Cast<HloGatherInstruction>(this);
  return gather->gather_dimension_numbers();
}
// Forwards to HloGatherInstruction::gather_slice_sizes() after casting `this`
// to that subclass with Cast<>.
absl::Span<const int64_t> HloInstruction::gather_slice_sizes() const {
  const auto* gather = Cast<HloGatherInstruction>(this);
  return gather->gather_slice_sizes();
}
// Forwards to HloScatterInstruction::scatter_dimension_numbers() after
// casting `this` to that subclass with Cast<>.
const ScatterDimensionNumbers& HloInstruction::scatter_dimension_numbers()
    const {
  const auto* scatter_instr = Cast<HloScatterInstruction>(this);
  return scatter_instr->scatter_dimension_numbers();
}
// Forwards to HloDotInstruction::dot_dimension_numbers() after casting `this`
// to that subclass with Cast<>.
const DotDimensionNumbers& HloInstruction::dot_dimension_numbers() const {
  const auto* dot = Cast<HloDotInstruction>(this);
  return dot->dot_dimension_numbers();
}
// Forwards to HloDomainInstruction::operand_side_metadata() after casting
// `this` to that subclass with Cast<>.
const DomainMetadata& HloInstruction::operand_side_metadata() const {
  const auto* domain = Cast<HloDomainInstruction>(this);
  return domain->operand_side_metadata();
}
// Forwards to HloDomainInstruction::user_side_metadata() after casting `this`
// to that subclass with Cast<>.
const DomainMetadata& HloInstruction::user_side_metadata() const {
  const auto* domain = Cast<HloDomainInstruction>(this);
  return domain->user_side_metadata();
}
// Returns true exactly when this instruction's opcode is kAsyncStart,
// kAsyncUpdate or kAsyncDone.
bool HloInstruction::IsAsynchronous() const {
  switch (opcode()) {
    case HloOpcode::kAsyncStart:
    case HloOpcode::kAsyncUpdate:
    case HloOpcode::kAsyncDone:
      return true;
    default:
      return false;
  }
}
// Returns the first entry of called_computations(). CHECKs beforehand that
// IsAsynchronous() holds for this instruction.
HloComputation* HloInstruction::async_wrapped_computation() const {
  CHECK(IsAsynchronous());
  const auto& computations = called_computations();
  return computations[0];
}
// Forwards to HloAsyncInstruction::async_wrapped_instruction() after casting
// `this` to that subclass with Cast<>.
HloInstruction* HloInstruction::async_wrapped_instruction() const {
  const auto* async_instr = Cast<HloAsyncInstruction>(this);
  return async_instr->async_wrapped_instruction();
}
// Forwards to HloAsyncInstruction::async_wrapped_opcode() after casting
// `this` to that subclass with Cast<>.
HloOpcode HloInstruction::async_wrapped_opcode() const {
  const auto* async_instr = Cast<HloAsyncInstruction>(this);
  return async_instr->async_wrapped_opcode();
}
// Forwards to HloAsyncInstruction::async_group_id() after casting `this` to
// that subclass with Cast<>.
std::optional<int64_t> HloInstruction::async_group_id() const {
  const auto* async_instr = Cast<HloAsyncInstruction>(this);
  return async_instr->async_group_id();
}
// Passes `async_group_id` on to HloAsyncInstruction::set_async_group_id() for
// `this` cast to HloAsyncInstruction.
void HloInstruction::set_async_group_id(std::optional<int64_t> async_group_id) {
  auto* async_instr = Cast<HloAsyncInstruction>(this);
  async_instr->set_async_group_id(async_group_id);
}
// Forwards to HloCopyStartInstruction::is_cross_program_prefetch() after
// casting `this` to that subclass with Cast<>.
bool HloInstruction::is_cross_program_prefetch() const {
  const auto* copy_start = Cast<HloCopyStartInstruction>(this);
  return copy_start->is_cross_program_prefetch();
}
// Forwards to HloCompareInstruction::direction() after casting `this` to that
// subclass with Cast<>.
ComparisonDirection HloInstruction::comparison_direction() const {
  const auto* compare = Cast<HloCompareInstruction>(this);
  return compare->direction();
}
// Forwards to HloCompareInstruction::order() after casting `this` to that
// subclass with Cast<>.
ComparisonOrder HloInstruction::comparison_order() const {
  const auto* compare = Cast<HloCompareInstruction>(this);
  return compare->order();
}
// Forwards to HloTriangularSolveInstruction::triangular_solve_options() after
// casting `this` to that subclass with Cast<>.
const TriangularSolveOptions& HloInstruction::triangular_solve_options() const {
  const auto* triangular_solve = Cast<HloTriangularSolveInstruction>(this);
  return triangular_solve->triangular_solve_options();
}
// Forwards to HloCholeskyInstruction::cholesky_options() after casting `this`
// to that subclass with Cast<>.
const CholeskyOptions& HloInstruction::cholesky_options() const {
  const auto* cholesky = Cast<HloCholeskyInstruction>(this);
  return cholesky->cholesky_options();
}
} // namespace xla