/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include <algorithm>
#include <functional>
#include <numeric>
#include <set>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
namespace xla {
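// Returns true if `hlo` is an instruction whose opcode is permitted to carry
// called computations (kCall, kWhile, kFusion, and so on). Preprocess() below
// uses this to reject called computations attached to any other opcode.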
bool IsCallerInstruction(HloInstruction* hlo) {
switch (hlo->opcode()) {
case HloOpcode::kCall:
case HloOpcode::kConditional:
case HloOpcode::kWhile:
case HloOpcode::kAllReduce:
case HloOpcode::kMap:
case HloOpcode::kReduce:
case HloOpcode::kReduceWindow:
case HloOpcode::kScatter:
case HloOpcode::kSelectAndScatter:
case HloOpcode::kSort:
case HloOpcode::kFusion:
case HloOpcode::kCustomCall:
return true;
default:
return false;
}
}
namespace {
Status CheckOperandCount(const HloInstruction* hlo, int expected) {
if (hlo->operand_count() != expected) {
return InternalError("Expected %d operands for %s instruction: %s",
expected, HloOpcodeString(hlo->opcode()),
hlo->ToString());
}
return Status::OK();
}
Status CheckParameterCount(const HloInstruction* calling_instruction,
const HloComputation* computation, int expected) {
if (computation->num_parameters() != expected) {
return InternalError(
"Expected computation %s called from %s to have %d parameters, has %d",
computation->name(), calling_instruction->name(), expected,
computation->num_parameters());
}
return Status::OK();
}
} // namespace
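// Preprocess() runs before the opcode-specific Handle* method on every
// visited instruction and enforces invariants that do not depend on the
// opcode's shape semantics.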
Status ShapeVerifier::Preprocess(HloInstruction* hlo) {
if (!hlo->called_computations().empty() && !IsCallerInstruction(hlo)) {
return InternalError(
"Called computations specified for non-caller instruction %s",
hlo->ToString());
}
absl::optional<int> arity = HloOpcodeArity(hlo->opcode());
if (arity) {
TF_RETURN_IF_ERROR(CheckOperandCount(hlo, *arity));
}
return Status::OK();
}
Status ShapeVerifier::HandleElementwiseUnary(HloInstruction* hlo) {
return CheckUnaryShape(hlo);
}
Status ShapeVerifier::HandleElementwiseBinary(HloInstruction* hlo) {
return CheckBinaryShape(hlo);
}
Status ShapeVerifier::HandleClamp(HloInstruction* clamp) {
return CheckTernaryShape(clamp);
}
Status ShapeVerifier::HandleSelect(HloInstruction* select) {
return CheckTernaryShape(select);
}
Status ShapeVerifier::HandleTupleSelect(HloInstruction* tuple_select) {
return CheckTernaryShape(tuple_select);
}
Status ShapeVerifier::HandleConcatenate(HloInstruction* concatenate) {
std::vector<const Shape*> operand_shapes;
for (const HloInstruction* operand : concatenate->operands()) {
operand_shapes.push_back(&operand->shape());
}
return CheckShape(concatenate,
ShapeInference::InferConcatOpShape(
operand_shapes, concatenate->concatenate_dimension()));
}
Status ShapeVerifier::HandleConvert(HloInstruction* convert) {
return CheckShape(convert, ShapeInference::InferConvertShape(
convert->operand(0)->shape(),
convert->shape().element_type()));
}
Status ShapeVerifier::HandleBitcastConvert(HloInstruction* convert) {
return CheckShape(convert, ShapeInference::InferBitcastConvertShape(
convert->operand(0)->shape(),
convert->shape().element_type()));
}
Status ShapeVerifier::HandleCopy(HloInstruction* copy) {
return CheckUnaryShape(copy);
}
Status ShapeVerifier::HandleDot(HloInstruction* dot) {
TF_ASSIGN_OR_RETURN(const Shape expected,
ShapeInference::InferDotOpShape(
dot->operand(0)->shape(), dot->operand(1)->shape(),
dot->dot_dimension_numbers()));
return CheckShape(dot, expected);
}
Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) {
TF_ASSIGN_OR_RETURN(
const Shape expected,
ShapeInference::InferConvolveShape(
convolution->operand(0)->shape(), convolution->operand(1)->shape(),
convolution->feature_group_count(), convolution->batch_group_count(),
convolution->window(), convolution->convolution_dimension_numbers()));
return CheckShape(convolution, expected);
}
Status ShapeVerifier::HandleFft(HloInstruction* fft) {
TF_ASSIGN_OR_RETURN(
const Shape expected,
ShapeInference::InferFftShape(fft->operand(0)->shape(), fft->fft_type(),
fft->fft_length()));
return CheckShape(fft, expected);
}
Status ShapeVerifier::HandleTriangularSolve(HloInstruction* hlo) {
TF_ASSIGN_OR_RETURN(const Shape expected,
ShapeInference::InferTriangularSolveShape(
hlo->operand(0)->shape(), hlo->operand(1)->shape(),
hlo->triangular_solve_options()));
return CheckShape(hlo, expected);
}
Status ShapeVerifier::HandleCholesky(HloInstruction* hlo) {
TF_RETURN_IF_ERROR(CheckOperandCount(hlo, 1));
TF_ASSIGN_OR_RETURN(const Shape expected, ShapeInference::InferCholeskyShape(
hlo->operand(0)->shape()));
return CheckShape(hlo, expected);
}
// Checks that `hlo`'s set of ReplicaGroups:
//
// - names each replica 0 through n-1 exactly once, and
// - does not contain any empty ReplicaGroups.
//
// Note that although none of the groups may be empty, `hlo` is allowed to have
// 0 groups. That just means it has one big group.
//
// This is just a minimal set of checks; some instructions may have additional
// requirements. For example, all-to-all requires that all ReplicaGroups have
// the same number of replicas, but that isn't checked here.
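//
// For example, with a replica count of 4, {{0,1},{2,3}} is valid, while
// {{0,1},{1,2}} repeats replica 1 and {{0,1},{3}} never names replica 2, so
// both are rejected.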
static Status CheckReplicaGroups(HloInstruction* hlo) {
std::set<int64> replicas_seen;
for (const ReplicaGroup& g : hlo->replica_groups()) {
if (g.replica_ids().empty()) {
return InternalError("Instruction cannot have an empty replica group: %s",
hlo->ToString());
}
for (int64 i : g.replica_ids()) {
if (!replicas_seen.insert(i).second) {
return InternalError(
"Replica %d is repeated in instruction's replica-groups: %s", i,
hlo->ToString());
}
}
}
for (int64 i = 0; i < replicas_seen.size(); ++i) {
if (!replicas_seen.count(i)) {
return InternalError(
"Replica %d is not named in instruction's replica-groups: %s", i,
hlo->ToString());
}
}
// When the channel_id() or use_global_device_ids() is set, device ids in
// ReplicaGroup config no longer only mean replica ids. So we skip the check
// on the replica count.
if (auto channel_instr = DynCast<HloChannelInstruction>(hlo)) {
if (channel_instr->channel_id()) {
return Status::OK();
}
}
if (auto all_reduce = DynCast<HloAllReduceInstruction>(hlo)) {
if (all_reduce->use_global_device_ids()) {
return Status::OK();
}
}
int64 replica_count = hlo->GetModule()->config().replica_count();
if (!replicas_seen.empty() && replicas_seen.size() != replica_count) {
return InternalError(
"Replica count in HloModuleConfig is %d, but ReplicaGroup config "
"contains %d replicas: %s",
replica_count, replicas_seen.size(), hlo->ToString());
}
return Status::OK();
}
Status ShapeVerifier::HandleAllGather(HloInstruction* hlo) {
auto ag = Cast<HloAllGatherInstruction>(hlo);
TF_RETURN_IF_ERROR(CheckReplicaGroups(ag));
TF_RET_CHECK(ag->all_gather_dimension() >= 0);
TF_RET_CHECK(ag->all_gather_dimension() < ag->shape().rank());
TF_RET_CHECK(ag->all_gather_dimension() < ag->operand(0)->shape().rank());
if (ag->use_global_device_ids() && ag->replica_groups().empty()) {
return InternalError(
"Replica group must be specified when use_global_device_ids is true");
}
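// The output's all-gather dimension is shard_count times larger than the
// operand's: e.g. gathering an f32[4,8] operand over dimension 1 into an
// f32[4,32] result implies a shard count of 32 / 8 = 4.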
int64 shard_count = CeilOfRatio(
ag->shape().dimensions(ag->all_gather_dimension()),
ag->operand(0)->shape().dimensions(ag->all_gather_dimension()));
if (ag->channel_id().has_value()) {
if (ag->use_global_device_ids()) {
TF_RET_CHECK(shard_count == ag->replica_groups()[0].replica_ids_size());
} else {
if (ag->replica_groups().empty() ||
ag->replica_groups()[0].replica_ids_size() != 1) {
return InternalError(
"Replica group size must be 1 when use_global_device_ids is "
"false if the all-gather is also cross-partition");
}
}
} else if (!ag->replica_groups().empty()) {
// Cross-replica all-gather: shard count is subgroup size.
TF_RET_CHECK(shard_count == ag->replica_groups()[0].replica_ids_size());
}
return CheckShape(ag, ShapeInference::InferAllGatherShape(
ag->operand(0)->shape(), ag->all_gather_dimension(),
shard_count));
}
Status ShapeVerifier::HandleAllReduce(HloInstruction* crs) {
TF_RETURN_IF_ERROR(CheckReplicaGroups(crs));
std::vector<const Shape*> operand_shapes;
for (const HloInstruction* operand : crs->operands()) {
operand_shapes.push_back(&operand->shape());
}
return CheckShape(crs, ShapeInference::InferAllReduceShape(operand_shapes));
}
Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) {
TF_RETURN_IF_ERROR(CheckReplicaGroups(hlo));
auto* all_to_all = Cast<HloAllToAllInstruction>(hlo);
TF_RET_CHECK(all_to_all != nullptr);
if (all_to_all->split_dimension()) {
if (hlo->replica_groups().empty()) {
return InternalError(
"An array all-to-all must have an explicit replica_groups config");
}
}
// The size of each replica group must be the same (the split count of the
// operation). If the default replica group is used (an empty replica group,
// which, as checked above, cannot be an array all-to-all), infer the split
// count from the number of operands.
const int64 split_count = hlo->replica_groups().empty()
? hlo->operand_count()
: hlo->replica_groups()[0].replica_ids_size();
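// For example, replica_groups {{0,1},{2,3}} yields a split count of 2, so
// every group must contain exactly 2 replicas.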
for (const ReplicaGroup& g : hlo->replica_groups()) {
if (g.replica_ids_size() != split_count) {
return InternalError(
"Replica group has size %d, but all replica groups in an all-to-all "
"must have size N: %s",
g.replica_ids_size(), hlo->ToString());
}
}
if (all_to_all->split_dimension()) {
TF_RET_CHECK(hlo->operand_count() == 1);
return CheckShape(
hlo, ShapeInference::InferAllToAllShape(
hlo->operand(0)->shape(), *all_to_all->split_dimension(),
*all_to_all->split_dimension(), split_count));
} else {
std::vector<const Shape*> operand_shapes;
for (const HloInstruction* operand : hlo->operands()) {
operand_shapes.push_back(&operand->shape());
}
return CheckShape(hlo,
ShapeInference::InferAllToAllTupleShape(operand_shapes));
}
}
Status ShapeVerifier::HandlePartitionId(HloInstruction* hlo) {
return CheckShape(hlo, ShapeUtil::MakeShape(U32, {}));
}
Status ShapeVerifier::HandleReplicaId(HloInstruction* hlo) {
return CheckShape(hlo, ShapeUtil::MakeShape(U32, {}));
}
namespace {
Status CheckDuplicatedSourceOrTarget(HloInstruction* hlo) {
// A source or target cannot appear twice in the collective-permute's
// source-target pairs.
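// For example, {{0,1},{1,2}} is valid (replica 1 is a target once and a
// source once), but {{0,1},{0,2}} repeats source 0 and is rejected.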
absl::flat_hash_set<int64> seen_sources;
absl::flat_hash_set<int64> seen_targets;
for (const auto& p : hlo->source_target_pairs()) {
if (!seen_sources.insert(p.first).second) {
return InternalError(
"Source %d appears more than once in instruction's source-target "
"pairs: %s",
p.first, hlo->ToString());
}
if (!seen_targets.insert(p.second).second) {
return InternalError(
"Target %d appears more than once in instruction's source-target "
"pairs: %s",
p.second, hlo->ToString());
}
}
return Status::OK();
}
} // namespace
Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) {
TF_RETURN_IF_ERROR(CheckDuplicatedSourceOrTarget(hlo));
return CheckShape(hlo, ShapeInference::InferCollectivePermuteShape(
hlo->operand(0)->shape()));
}
Status ShapeVerifier::HandleCollectivePermuteStart(HloInstruction* hlo) {
TF_RETURN_IF_ERROR(CheckDuplicatedSourceOrTarget(hlo));
return CheckShape(
hlo, ShapeUtil::MakeTupleShape(
{hlo->operand(0)->shape(), hlo->operand(0)->shape(),
ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {})}));
}
Status ShapeVerifier::HandleCollectivePermuteDone(HloInstruction* hlo) {
return CheckShape(
hlo, ShapeUtil::GetTupleElementShape(hlo->operand(0)->shape(), 0));
}
Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) {
return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape(
reduce_precision->operand(0)->shape(),
reduce_precision->exponent_bits(),
reduce_precision->mantissa_bits()));
}
Status ShapeVerifier::CheckIsTokenOperand(const HloInstruction* instruction,
int64 operand_no) {
const HloInstruction* token = instruction->operand(operand_no);
if (!ShapeUtil::Equal(token->shape(), ShapeUtil::MakeTokenShape())) {
return InternalError(
"Expected operand %d to be token-shaped, actual shape is "
"%s:\n%s",
operand_no, StringifyShape(token->shape()), instruction->ToString());
}
return Status::OK();
}
Status ShapeVerifier::CheckOperandAndParameter(
const HloInstruction* instruction, int64 operand_number,
const HloComputation* computation, int64 parameter_number) {
const HloInstruction* operand = instruction->operand(operand_number);
const HloInstruction* parameter =
computation->parameter_instruction(parameter_number);
if (!ShapesSame(operand->shape(), parameter->shape())) {
return InternalError("Operand %s shape does not match parameter's %s in %s",
operand->ToString(), parameter->ToString(),
instruction->ToString());
}
return Status::OK();
}
Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) {
HloInfeedInstruction* infeed = Cast<HloInfeedInstruction>(instruction);
TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0));
// The output of infeed is a tuple containing the data value and a token.
return CheckShape(infeed,
ShapeUtil::MakeTupleShape(
{infeed->infeed_shape(), ShapeUtil::MakeTokenShape()}));
}
Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) {
HloOutfeedInstruction* outfeed = Cast<HloOutfeedInstruction>(instruction);
TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 1));
// Outfeed has a separate shape field for the value which is outfed to the
// host. The shape of the instruction itself is always a token.
if (!ShapesSame(outfeed->outfeed_shape(), outfeed->operand(0)->shape())) {
return InternalError(
"Expected outfeed shape to be equal to operand's shape %s, "
"actual shape is %s:\n%s",
StringifyShape(outfeed->operand(0)->shape()),
StringifyShape(outfeed->outfeed_shape()), outfeed->ToString());
}
return CheckShape(outfeed, ShapeUtil::MakeTokenShape());
}
bool ShapeVerifier::HasCompatibleElementTypes(const Shape& shape_0,
const Shape& shape_1,
const Shape& result_shape) {
return ShapeUtil::SameElementType(shape_0, shape_1) &&
(ShapeUtil::SameElementType(shape_0, result_shape) ||
(allow_mixed_precision_ &&
ShapeUtil::SameElementTypeIgnoringFpPrecision(shape_0,
result_shape)));
}
Status ShapeVerifier::HandleRng(HloInstruction* instruction) {
TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 2));
const Shape& shape_0 = instruction->operand(0)->shape();
const Shape& shape_1 = instruction->operand(1)->shape();
if (!ShapeUtil::IsScalar(shape_0) || !ShapeUtil::IsScalar(shape_1)) {
return InternalError(
"Expected scalar types for the two operands of Rng instruction: %s",
instruction->ToString());
}
if (!HasCompatibleElementTypes(shape_0, shape_1, instruction->shape())) {
return InternalError(
"Expected compatible element types for the result and the two operands"
" of Rng instruction: %s",
instruction->ToString());
}
PrimitiveType element_type = shape_0.element_type();
switch (instruction->random_distribution()) {
case RNG_UNIFORM:
if (!primitive_util::IsFloatingPointType(element_type) &&
!primitive_util::IsIntegralType(element_type) &&
element_type != PRED) {
return InternalError(
"Element type not supported."
" Expected element to be of floating point type, integral type or"
" predicate type for RngUniform: %s",
instruction->ToString());
}
break;
case RNG_NORMAL:
if (!primitive_util::IsFloatingPointType(element_type)) {
return InternalError(
"Element type not supported."
" Expected element to be FloatingPointType for RngNormal: %s",
instruction->ToString());
}
break;
default:
return InternalError(
"Invalid Rng distribution %s",
RandomDistribution_Name(instruction->random_distribution()));
}
return Status::OK();
}
Status ShapeVerifier::HandleRngBitGenerator(HloInstruction* hlo) {
if (!hlo->shape().IsTuple() || hlo->shape().tuple_shapes_size() != 2) {
return InternalError(
"Expected tuple shape with 2 elements for RngBitGenerator. Got: %s",
hlo->shape().ToString());
}
if (!ShapeUtil::Compatible(hlo->operand(0)->shape(),
hlo->shape().tuple_shapes(0))) {
return InternalError(
"Expected state shape to match between input and output for "
"RngBitGenerator. Got %s vs. %s",
hlo->operand(0)->shape().ToString(),
hlo->shape().tuple_shapes(0).ToString());
}
return Status::OK();
}
Status ShapeVerifier::HandleRngGetAndUpdateState(HloInstruction* instruction) {
TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 0));
const Shape& result_shape = instruction->shape();
const Shape expected_shape = ShapeUtil::MakeShape(U64, {2});
if (!ShapeUtil::Compatible(result_shape, expected_shape)) {
return InternalError(
"Invalid RngGetAndUpdateState, expect result to have shape %s, got %s ",
StringifyShape(expected_shape), StringifyShape(result_shape));
}
return Status::OK();
}
Status ShapeVerifier::HandleReverse(HloInstruction* reverse) {
return CheckShape(
reverse, ShapeInference::InferReverseShape(reverse->operand(0)->shape(),
reverse->dimensions()));
}
Status ShapeVerifier::HandleSort(HloInstruction* sort) {
if (sort->operand_count() < 1) {
return InternalError("Expected at least 1 operand for %s instruction: %s",
HloOpcodeString(sort->opcode()), sort->ToString());
}
HloComputation* compare = sort->to_apply();
// Check that the 'compare' computation returns a PRED.
Shape compare_shape = compare->root_instruction()->shape();
if (!ShapeUtil::Compatible(compare_shape, ShapeUtil::MakeShape(PRED, {}))) {
return InternalError(
"The Sort compare computation shape does not lead to a scalar "
"predicate shape: %s",
StringifyShape(compare_shape));
}
// Check that the number of parameters of the 'compare' computation is
// correct.
TF_RETURN_IF_ERROR(
CheckParameterCount(sort, compare, sort->operand_count() * 2));
// Verify that the parameters of the 'compare' computation have the correct
// scalar shapes.
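// Parameters 2*i and 2*i+1 are the lhs/rhs elements of operand i: a sort of
// (s32[...] keys, f32[...] values) takes a compare computation with
// parameters (s32[], s32[], f32[], f32[]).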
for (int64 parameter_idx = 0; parameter_idx < compare->num_parameters();
++parameter_idx) {
int64 operand_idx = parameter_idx / 2;
Shape expected_scalar_shape = ShapeUtil::MakeShape(
sort->operand(operand_idx)->shape().element_type(), {});
Shape actual_parameter_shape =
compare->parameter_instruction(parameter_idx)->shape();
if (!ShapeUtil::CompatibleIgnoringFpPrecision(expected_scalar_shape,
actual_parameter_shape)) {
return InternalError(
"Expected the %lld-th parameter of the compare computation of sort "
"to have shape %s, but got %s",
parameter_idx, StringifyShape(expected_scalar_shape),
StringifyShape(actual_parameter_shape));
}
}
// Verify that all operand shapes have the same dimensions.
for (int64 operand = 1; operand < sort->operand_count(); ++operand) {
if (!ShapeUtil::SameDimensions(sort->operand(0)->shape(),
sort->operand(operand)->shape())) {
return InternalError(
"Expected sort to have to have the same dimensions for all operands. "
"First operand shape is: %s\n, shape (operand index %lld) is: %s",
StringifyShape(sort->operand(0)->shape()), operand,
StringifyShape(sort->operand(operand)->shape()));
}
}
return CheckVariadicShape(sort);
}
Status ShapeVerifier::HandleConstant(HloInstruction* constant) {
if (!Cast<HloConstantInstruction>(constant)->HasLiteral()) {
return InternalError("Constant is required to have a valid literal: %s",
constant->ToString());
}
return CheckShape(constant, constant->literal().shape(),
/*only_compare_minor_to_major_in_layout=*/true);
}
Status ShapeVerifier::HandleIota(HloInstruction* instruction) {
auto* iota = Cast<HloIotaInstruction>(instruction);
if (!iota->shape().IsArray()) {
return InternalError("Iota does not support non-array result.");
}
const int64 rank = iota->shape().rank();
if (rank == 0) {
return InternalError("Iota does not support scalars.");
}
int64 iota_dimension = iota->iota_dimension();
if (iota_dimension >= rank || iota_dimension < 0) {
return InternalError(
"The iota dimension cannot go beyond the operation rank or be "
"negative.");
}
PrimitiveType primitive_type = iota->shape().element_type();
if (!primitive_util::IsIntegralType(primitive_type) &&
!primitive_util::IsFloatingPointType(primitive_type) &&
!primitive_util::IsComplexType(primitive_type)) {
return InvalidArgument(
"Only support iota of integral, floating point or complex primitive "
"types, got %s",
PrimitiveType_Name(primitive_type));
}
return Status::OK();
}
Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) {
return CheckShape(get_tuple_element,
ShapeInference::InferGetTupleElementShape(
get_tuple_element->operand(0)->shape(),
get_tuple_element->tuple_index()));
}
namespace {
Status SameElementTypesForOperandsAndToApplyParameters(
const HloInstruction& instruction, int64 num_operands_to_check) {
const ProgramShape& to_apply = instruction.to_apply()->ComputeProgramShape();
for (int i = 0; i < num_operands_to_check; ++i) {
const Shape& parameter_shape = to_apply.parameters(i);
const Shape& operand_shape = instruction.operands()[i]->shape();
if (!ShapeUtil::SameElementType(parameter_shape, operand_shape)) {
return InvalidArgument(
"Shape mismatch between to_apply computation"
" parameter and operand %d in %s.",
i, instruction.ToString().c_str());
}
}
return Status::OK();
}
} // namespace
Status ShapeVerifier::HandleReduce(HloInstruction* reduce) {
if (reduce->operand_count() % 2 != 0) {
return InternalError(
"Expected an even number of operands for %s instruction: %s",
HloOpcodeString(reduce->opcode()), reduce->ToString());
}
std::vector<const Shape*> operand_shapes;
for (const HloInstruction* operand : reduce->operands()) {
operand_shapes.push_back(&operand->shape());
}
TF_RETURN_IF_ERROR(
CheckShape(reduce, ShapeInference::InferReduceShape(
operand_shapes, reduce->dimensions(),
reduce->to_apply()->ComputeProgramShape())));
return allow_mixed_precision_
? Status::OK()
: SameElementTypesForOperandsAndToApplyParameters(
*reduce, reduce->operands().size() - 1);
}
Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) {
if (layout_sensitive_ &&
shape_size_function_(bitcast->shape()) !=
shape_size_function_(bitcast->operand(0)->shape())) {
return InternalError(
"Bitcast cannot have different shape sizes of output (%d) and operand "
"(%d) (%s) (%s)",
shape_size_function_(bitcast->shape()),
shape_size_function_(bitcast->operand(0)->shape()),
bitcast->shape().ToString(true),
bitcast->operand(0)->shape().ToString(true));
}
return Status::OK();
}
Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) {
// HLO broadcast has no exact analog at the proto level so there is no
// ShapeInference method. Check the output shape explicitly.
const Shape& operand_shape = broadcast->operand(0)->shape();
// Check for mixed precision.
TF_RET_CHECK(SameElementType(broadcast->shape(), operand_shape));
TF_RET_CHECK(operand_shape.rank() == broadcast->dimensions().size());
for (int64 operand_dimension = 0; operand_dimension < operand_shape.rank();
++operand_dimension) {
int64 output_dimension = broadcast->dimensions()[operand_dimension];
TF_RET_CHECK((output_dimension < broadcast->shape().rank()) &&
output_dimension >= 0 &&
(broadcast->shape().dimensions(output_dimension) ==
operand_shape.dimensions(operand_dimension)))
<< broadcast->ToString() << " operand shape " << operand_shape;
}
return Status::OK();
}
Status ShapeVerifier::HandleDynamicReshape(HloInstruction* dynamic_reshape) {
// Check for mixed precision.
const Shape& operand_shape = dynamic_reshape->operand(0)->shape();
TF_RET_CHECK(SameElementType(dynamic_reshape->shape(), operand_shape));
TF_RET_CHECK(ShapeUtil::ElementsIn(dynamic_reshape->shape()) ==
ShapeUtil::ElementsIn(operand_shape));
TF_RET_CHECK(dynamic_reshape->shape().rank() + 1 ==
dynamic_reshape->operand_count());
for (int64 i = 1; i < dynamic_reshape->operand_count(); ++i) {
TF_RET_CHECK(dynamic_reshape->operand(i)->shape().element_type() == S32);
}
return Status::OK();
}
Status ShapeVerifier::HandleReshape(HloInstruction* reshape) {
// Check for mixed precision.
const Shape& operand_shape = reshape->operand(0)->shape();
TF_RET_CHECK(SameElementType(reshape->shape(), operand_shape));
TF_RET_CHECK(ShapeUtil::ElementsIn(reshape->shape()) ==
ShapeUtil::ElementsIn(operand_shape));
return Status::OK();
}
Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) {
return CheckShape(
transpose, ShapeInference::InferTransposeShape(
transpose->operand(0)->shape(), transpose->dimensions()));
}
Status ShapeVerifier::HandleParameter(HloInstruction* hlo) {
return Status::OK();
}
Status ShapeVerifier::HandleFusion(HloInstruction* fusion) {
if (fusion->called_computations().size() != 1) {
return InternalError(
"Fusion has a non-unary number of called computations (%s)",
fusion->ToString().c_str());
}
const Shape& root_computation_shape =
fusion->called_computations()[0]->root_instruction()->shape();
if (!ShapesSame(fusion->shape(), root_computation_shape)) {
return InternalError(
"Fused computation shape (%s) is not equal to the fusion shape (%s)",
root_computation_shape.ToString(true), fusion->shape().ToString(true));
}
auto& fused_parameters = fusion->fused_parameters();
if (fused_parameters.size() != fusion->operand_count()) {
return InternalError(
"Fused parameter count (%d) does not match the number of operands (%d)"
" passed to the fusion instruction in: %s.",
fused_parameters.size(), fusion->operand_count(),
fusion->ToString().c_str());
}
for (HloInstruction* fused_param : fused_parameters) {
int64 param_no = fused_param->parameter_number();
if (!ShapesSame(fused_param->shape(), fusion->operand(param_no)->shape())) {
return InternalError(
"Shape mismatch between parameter number %d and its operand in "
"%s.",
param_no, fusion->ToString().c_str());
}
}
return Status::OK();
}
Status ShapeVerifier::HandleCall(HloInstruction* call) {
TF_RETURN_IF_ERROR(
CheckParameterCount(call, call->to_apply(), call->operand_count()));
for (int64 i = 0; i < call->to_apply()->num_parameters(); ++i) {
TF_RETURN_IF_ERROR(CheckOperandAndParameter(call, i, call->to_apply(), i));
}
// The shape of kCall should match the shape of the computation it calls.
return CheckShape(call, call->to_apply()->root_instruction()->shape());
}
Status ShapeVerifier::HandleCustomCall(HloInstruction* instruction) {
const HloCustomCallInstruction* custom_call =
DynCast<const HloCustomCallInstruction>(instruction);
TF_RET_CHECK(custom_call != nullptr);
if (custom_call->layout_constrained()) {
// If the layout is constrained, verify all the respective shapes have
// layouts and that the constrained operand shapes match the shapes of the
// operands.
TF_RET_CHECK(LayoutUtil::HasLayout(custom_call->shape()));
TF_RET_CHECK(custom_call->operand_count() ==
custom_call->operand_shapes_with_layout().size());
for (int64 i = 0; i < custom_call->operand_count(); ++i) {
const Shape& operand_shape_with_layout =
custom_call->operand_shapes_with_layout()[i];
TF_RET_CHECK(ShapeUtil::Compatible(custom_call->operand(i)->shape(),
operand_shape_with_layout))
<< custom_call->operand(i)->shape().ToString() << " operand "
<< operand_shape_with_layout.ToString();
TF_RET_CHECK(LayoutUtil::HasLayout(operand_shape_with_layout));
}
}
return Status::OK();
}
Status ShapeVerifier::HandleSlice(HloInstruction* slice) {
return CheckShape(slice,
ShapeInference::InferSliceShape(
slice->operand(0)->shape(), slice->slice_starts(),
slice->slice_limits(), slice->slice_strides()));
}
Status ShapeVerifier::HandleDynamicSlice(HloInstruction* dynamic_slice) {
return CheckShape(
dynamic_slice,
ShapeInference::InferDynamicSliceShape(
dynamic_slice->operand(0)->shape(),
Cast<HloDynamicSliceInstruction>(dynamic_slice)->index_shapes(),
dynamic_slice->dynamic_slice_sizes()));
}
Status ShapeVerifier::HandleDynamicUpdateSlice(
HloInstruction* dynamic_update_slice) {
return CheckShape(
dynamic_update_slice,
ShapeInference::InferDynamicUpdateSliceShape(
dynamic_update_slice->operand(0)->shape(),
dynamic_update_slice->operand(1)->shape(),
Cast<HloDynamicUpdateSliceInstruction>(dynamic_update_slice)
->index_shapes()));
}
Status ShapeVerifier::HandleTuple(HloInstruction* tuple) {
return CheckVariadicShape(tuple);
}
Status ShapeVerifier::HandleMap(HloInstruction* map) {
std::vector<const Shape*> operand_shapes;
int64 max_operand_rank = 0;
for (const HloInstruction* operand : map->operands()) {
operand_shapes.push_back(&operand->shape());
max_operand_rank = std::max(max_operand_rank, operand->shape().rank());
}
// TODO(b/65689298) Remove code below once Map is generalized to accept
// arbitrary map dimensions.
std::vector<int64> map_dims(max_operand_rank);
std::iota(map_dims.begin(), map_dims.end(), 0);
TF_RETURN_IF_ERROR(CheckShape(
map,
ShapeInference::InferMapShape(
operand_shapes, map->to_apply()->ComputeProgramShape(), map_dims)));
return allow_mixed_precision_
? Status::OK()
: SameElementTypesForOperandsAndToApplyParameters(
*map, map->operands().size());
}
Status ShapeVerifier::HandleReduceWindow(HloInstruction* reduce_window) {
TF_RETURN_IF_ERROR(CheckShape(
reduce_window,
ShapeInference::InferReduceWindowShape(
reduce_window->operand(0)->shape(),
reduce_window->operand(1)->shape(), reduce_window->window(),
reduce_window->to_apply()->ComputeProgramShape())));
return allow_mixed_precision_
? Status::OK()
: SameElementTypesForOperandsAndToApplyParameters(*reduce_window,
1);
}
Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) {
return CheckShape(
instruction,
ShapeInference::InferSelectAndScatterShape(
instruction->operand(0)->shape(),
instruction->select()->ComputeProgramShape(), instruction->window(),
instruction->operand(1)->shape(), instruction->operand(2)->shape(),
instruction->scatter()->ComputeProgramShape()));
}
Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) {
TF_RETURN_IF_ERROR(
CheckParameterCount(xla_while, xla_while->while_body(), 1));
TF_RETURN_IF_ERROR(
CheckParameterCount(xla_while, xla_while->while_condition(), 1));
TF_RETURN_IF_ERROR(
CheckOperandAndParameter(xla_while, 0, xla_while->while_body(), 0));
TF_RETURN_IF_ERROR(
CheckOperandAndParameter(xla_while, 0, xla_while->while_condition(), 0));
const Shape& conditional_shape =
xla_while->while_condition()->root_instruction()->shape();
if (!ShapeUtil::Compatible(conditional_shape,
ShapeUtil::MakeShape(PRED, {}))) {
return InternalError(
"Conditional computation shape does not lead to a scalar predicate "
"shape: %s",
StringifyShape(conditional_shape));
}
// The shape of kWhile should match the shape of the body computation it
// calls.
return CheckShape(xla_while,
xla_while->while_body()->root_instruction()->shape());
}
Status ShapeVerifier::HandleConditional(HloInstruction* conditional) {
if (!ShapeUtil::IsScalar(conditional->operand(0)->shape())) {
return InvalidArgument(
"The first operand of conditional must be a scalar. Got %s",
conditional->operand(0)->shape().DebugString());
}
const int num_branches = conditional->branch_count();
PrimitiveType operand0_type = conditional->operand(0)->shape().element_type();
if (operand0_type == PRED) {
TF_RET_CHECK(num_branches == 2);
} else {
if (operand0_type != S32) {
return InvalidArgument(
"The first operand of indexed conditional must be a scalar of S32. "
"Got"
" type %s.",
PrimitiveType_Name(operand0_type));
}
TF_RET_CHECK(num_branches >= 1);
}
TF_RETURN_IF_ERROR(CheckOperandCount(conditional, num_branches + 1));
for (int j = 0; j < num_branches; ++j) {
TF_RETURN_IF_ERROR(CheckParameterCount(
conditional, conditional->branch_computation(j), 1));
TF_RETURN_IF_ERROR(CheckOperandAndParameter(
conditional, j + 1, conditional->branch_computation(j), 0));
TF_RETURN_IF_ERROR(CheckShape(
conditional,
conditional->branch_computation(j)->root_instruction()->shape()));
}
return Status::OK();
}
Status ShapeVerifier::HandlePad(HloInstruction* pad) {
return CheckShape(pad, ShapeInference::InferPadShape(pad->operand(0)->shape(),
pad->operand(1)->shape(),
pad->padding_config()));
}
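// CopyStart produces a (destination, source, context) tuple in which both
// buffers have the operand's shape and the context is a U32 scalar.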
Status ShapeVerifier::HandleCopyStart(HloInstruction* copy_start) {
return CheckShape(copy_start,
ShapeUtil::MakeTupleShape({copy_start->operand(0)->shape(),
copy_start->operand(0)->shape(),
ShapeUtil::MakeShape(U32, {})}),
/*only_compare_minor_to_major_in_layout=*/true);
}
Status ShapeVerifier::HandleCopyDone(HloInstruction* copy_done) {
const Shape& operand_shape = copy_done->operand(0)->shape();
const Shape& dest_shape = ShapeUtil::GetTupleElementShape(operand_shape, 0);
const Shape& src_shape = ShapeUtil::GetTupleElementShape(operand_shape, 1);
if (!ShapesSame(dest_shape, src_shape,
/*minor_to_major_only=*/false,
/*ignore_memory_space=*/true)) {
return InternalError(
"Source and destination buffers in CopyDone arguments need to be the "
"same shape found %s and %s\n%s",
StringifyShape(dest_shape), StringifyShape(src_shape),
copy_done->ToString());
}
return CheckShape(copy_done, ShapeUtil::GetTupleElementShape(
copy_done->operand(0)->shape(), 0));
}
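// Send produces an (operand, U32 request identifier, token) tuple.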
Status ShapeVerifier::HandleSend(HloInstruction* send) {
return CheckShape(send,
ShapeUtil::MakeTupleShape({send->operand(0)->shape(),
ShapeUtil::MakeShape(U32, {}),
ShapeUtil::MakeTokenShape()}),
/*only_compare_minor_to_major_in_layout=*/true);
}
Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) {
return CheckShape(send_done, ShapeUtil::MakeTokenShape());
}
Status ShapeVerifier::HandleRecv(HloInstruction* recv) {
return CheckShape(
recv,
ShapeUtil::MakeTupleShape(
{ShapeUtil::GetTupleElementShape(recv->shape(), 0),
ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}),
/*only_compare_minor_to_major_in_layout=*/true);
}
Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) {
return CheckShape(
recv_done,
ShapeUtil::MakeTupleShape(
{ShapeUtil::GetTupleElementShape(recv_done->operand(0)->shape(), 0),
ShapeUtil::MakeTokenShape()}));
}
Status ShapeVerifier::HandleBatchNormTraining(
HloInstruction* batch_norm_training) {
return CheckShape(batch_norm_training,
ShapeInference::InferBatchNormTrainingShape(
batch_norm_training->operand(0)->shape(),
batch_norm_training->operand(1)->shape(),
batch_norm_training->operand(2)->shape(),
batch_norm_training->feature_index()));
}
Status ShapeVerifier::HandleBatchNormInference(
HloInstruction* batch_norm_inference) {
return CheckShape(batch_norm_inference,
ShapeInference::InferBatchNormInferenceShape(
batch_norm_inference->operand(0)->shape(),
batch_norm_inference->operand(1)->shape(),
batch_norm_inference->operand(2)->shape(),
batch_norm_inference->operand(3)->shape(),
batch_norm_inference->operand(4)->shape(),
batch_norm_inference->feature_index()));
}
Status ShapeVerifier::HandleBatchNormGrad(HloInstruction* batch_norm_grad) {
return CheckShape(batch_norm_grad, ShapeInference::InferBatchNormGradShape(
batch_norm_grad->operand(0)->shape(),
batch_norm_grad->operand(1)->shape(),
batch_norm_grad->operand(2)->shape(),
batch_norm_grad->operand(3)->shape(),
batch_norm_grad->operand(4)->shape(),
batch_norm_grad->feature_index()));
}
namespace {
// Checks that the instruction does not have mixed precision floating point
// inputs.
Status CheckMixedPrecisionOperands(const HloInstruction* instruction) {
switch (instruction->opcode()) {
// Allow-list the following opcodes for mixed-precision check, because
// they involve data pass through or grouping via tuples, where the
// precisions of buffers can be different.
case HloOpcode::kCall:
case HloOpcode::kConditional:
case HloOpcode::kConstant:
case HloOpcode::kAllReduce:
case HloOpcode::kCopyDone:
case HloOpcode::kCopyStart:
case HloOpcode::kCustomCall:
case HloOpcode::kDomain:
case HloOpcode::kFusion:
case HloOpcode::kGetTupleElement:
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kParameter:
case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
case HloOpcode::kReducePrecision:
case HloOpcode::kTupleSelect:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kSort:
case HloOpcode::kTuple:
case HloOpcode::kWhile:
break;
default: {
PrimitiveType fp_type = PRIMITIVE_TYPE_INVALID;
for (auto operand : instruction->operands()) {
TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
operand->shape(),
[&](const Shape& subshape, const ShapeIndex& index) {
if (!ShapeUtil::ElementIsFloating(subshape)) {
return Status::OK();
}
if (fp_type == PRIMITIVE_TYPE_INVALID) {
fp_type = subshape.element_type();
} else if (fp_type != subshape.element_type()) {
return InternalError(
"Seen floating point types of different precisions in "
"%s, but mixed precision is disallowed.",
instruction->ToString());
}
return Status::OK();
}));
}
}
}
return Status::OK();
}
} // namespace
Status ShapeVerifier::HandleGather(HloInstruction* gather) {
return CheckShape(
gather,
ShapeInference::InferGatherShape(
gather->operand(0)->shape(), gather->operand(1)->shape(),
gather->gather_dimension_numbers(), gather->gather_slice_sizes()));
}
Status ShapeVerifier::HandleScatter(HloInstruction* scatter) {
return CheckShape(
scatter, ShapeInference::InferScatterShape(
scatter->operand(0)->shape(), scatter->operand(1)->shape(),
scatter->operand(2)->shape(),
scatter->to_apply()->ComputeProgramShape(),
scatter->scatter_dimension_numbers()));
}
Status ShapeVerifier::HandleAfterAll(HloInstruction* token) {
std::vector<const Shape*> operand_shapes;
for (const HloInstruction* operand : token->operands()) {
operand_shapes.push_back(&operand->shape());
}
return CheckShape(token, ShapeUtil::MakeTokenShape());
}
Status ShapeVerifier::HandleAddDependency(HloInstruction* add_dependency) {
TF_RETURN_IF_ERROR(CheckIsTokenOperand(add_dependency, 1));
return CheckShape(add_dependency, add_dependency->operand(0)->shape());
}
Status ShapeVerifier::HandleGetDimensionSize(HloInstruction* get_size) {
return CheckShape(get_size,
ShapeInference::InferGetDimensionSizeShape(
get_size->operand(0)->shape(), get_size->dimension()));
}
Status ShapeVerifier::HandleSetDimensionSize(HloInstruction* set_size) {
return CheckShape(set_size,
ShapeInference::InferSetDimensionSizeShape(
set_size->operand(0)->shape(),
set_size->operand(1)->shape(), set_size->dimension()));
}
Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
const Shape& inferred_shape,
bool only_compare_minor_to_major_in_layout) {
// If allow_mixed_precision_ is false, check if there are operands with
// different precisions. We need this check because ShapeInference allows
// mixed precision inputs.
if (!allow_mixed_precision_) {
TF_RETURN_IF_ERROR(CheckMixedPrecisionOperands(instruction));
}
// Check if the output shape matches the expected shape.
//
// We treat BF16 and F32 as compatible types if mixed precision is allowed,
// but only when the instruction defines the BF16/F32 buffer.
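//
// For example, when mixed precision is allowed, an instruction whose inferred
// shape is f32[8] may declare its result as bf16[8].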
bool equal = [&] {
switch (instruction->opcode()) {
// The opcodes below can't have implicit layout conversions, nor can they
// implicitly transform f32 -> bf16. Fundamentally these are either
// reinterpreting existing data (e.g. kBitcast) or shuffling data around
// without modifying it (e.g. kGetTupleElement, kTupleSelect).
case HloOpcode::kBitcast:
case HloOpcode::kCall:
case HloOpcode::kConditional:
case HloOpcode::kConstant:
case HloOpcode::kCopyDone:
case HloOpcode::kCopyStart:
case HloOpcode::kCustomCall:
case HloOpcode::kGetTupleElement:
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kParameter:
case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kTuple:
case HloOpcode::kTupleSelect:
case HloOpcode::kWhile:
return ShapesSame(instruction->shape(), inferred_shape,
only_compare_minor_to_major_in_layout);
// We allow arbitrary layout and f32->bf16 transformations on all other
// instructions, although this may be made more strict pending discussion
// in b/112709536.
default:
if (allow_mixed_precision_) {
return ShapeUtil::CompatibleIgnoringFpPrecision(instruction->shape(),
inferred_shape);
} else {
return ShapeUtil::Compatible(instruction->shape(), inferred_shape);
}
}
}();
if (!equal) {
return InternalError(
"Expected instruction to have shape equal to %s, actual "
"shape is %s:\n%s",
StringifyShape(inferred_shape), StringifyShape(instruction->shape()),
instruction->ToString());
}
return Status::OK();
}
Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
const StatusOr<Shape>& inferred_shape_status) {
if (!inferred_shape_status.ok()) {
Status s = inferred_shape_status.status();
tensorflow::errors::AppendToMessage(&s, ", for instruction ",
instruction->ToString());
return s;
}
return CheckShape(instruction, inferred_shape_status.ValueOrDie());
}
Status ShapeVerifier::CheckUnaryShape(const HloInstruction* instruction) {
return CheckShape(instruction,
ShapeInference::InferUnaryOpShape(instruction->opcode(),
instruction->operand(0)));
}
Status ShapeVerifier::CheckBinaryShape(const HloInstruction* instruction) {
return CheckShape(
instruction, ShapeInference::InferBinaryOpShape(instruction->opcode(),
instruction->operand(0),
instruction->operand(1)));
}
Status ShapeVerifier::CheckTernaryShape(const HloInstruction* instruction) {
return CheckShape(instruction,
ShapeInference::InferTernaryOpShape(
instruction->opcode(), instruction->operand(0),
instruction->operand(1), instruction->operand(2)));
}
Status ShapeVerifier::CheckVariadicShape(const HloInstruction* instruction) {
return CheckShape(instruction,
ShapeInference::InferVariadicOpShape(
instruction->opcode(), instruction->operands()));
}
Status ShapeVerifier::VerifyEntryComputationLayout(const HloModule& module) {
const HloComputation* computation = module.entry_computation();
const auto& layout = module.entry_computation_layout();
const ShapeLayout& result_layout = layout.result_layout();
TF_RETURN_IF_ERROR(
ShapeUtil::ValidateShapeWithOptionalLayout(result_layout.shape()));
if (!ShapeUtil::Compatible(computation->root_instruction()->shape(),
result_layout.shape())) {
return InternalError(
"Shape of the root instruction of entry computation (%s) should be "
"compatible to one specified in module's entry computation layout (%s)",
ShapeUtil::HumanString(computation->root_instruction()->shape()),
ShapeUtil::HumanString(result_layout.shape()));
}
if (computation->num_parameters() != layout.parameter_count()) {
return InternalError(
"Number of parameters in entry computation layout (%d) must be same "
"as number of parameters of entry computation (%d)",
layout.parameter_count(), computation->num_parameters());
}
for (int i = 0; i < computation->num_parameters(); ++i) {
const HloInstruction* parameter = computation->parameter_instruction(i);
TF_RETURN_IF_ERROR(
ShapeUtil::ValidateShapeWithOptionalLayout(layout.parameter_shape(i)));
if (!ShapeUtil::Compatible(parameter->shape(), layout.parameter_shape(i))) {
return InternalError(
"Shape of the entry computation parameter %d is %s should be "
"compatible to the one specified in module's entry computation "
"layout %s",
i, ShapeUtil::HumanString(parameter->shape()),
ShapeUtil::HumanString(layout.parameter_shape(i)));
}
}
return Status::OK();
}
string ComputationsToString(absl::Span<HloComputation* const> computations) {
return absl::StrJoin(computations, ",",
[](string* s, const HloComputation* computation) {
s->append(computation->name());
});
}
// Verifies various invariants about the structure of the HLO:
//
// (1) each instruction has a non-null parent() set to the HloComputation
// which contains it.
//
// (2) each computation has a non-null parent() set to the HloModule which
// contains it.
//
// (3) the operands of each instruction are in the same computation as the
// instruction.
Status VerifyHloStructure(HloModule* module) {
for (const HloComputation* computation : module->computations()) {
if (computation->parent() == nullptr) {
return InternalError("Computation %s has a null parent pointer",
computation->name());
}
if (computation->parent() != module) {
return InternalError(
"Computation %s parent() does not point to parent module",
computation->name());
}
for (const HloInstruction* instruction : computation->instructions()) {
if (instruction->parent() == nullptr) {
return InternalError("Instruction %s has a null parent pointer",
instruction->name());
}
if (instruction->parent() != computation) {
return InternalError(
"Instruction %s parent() does not point to parent computation",
instruction->name());
}
}
}
// Check that operands are in the same computation separately from verifying
// parent() correctness so conditions like a null HloInstruction::parent()
// are identified and reported explicitly above rather than reporting a
// mismatched operand.
for (const HloComputation* computation : module->computations()) {
for (const HloInstruction* instruction : computation->instructions()) {
for (int i = 0; i < instruction->operand_count(); ++i) {
const HloInstruction* operand = instruction->operand(i);
if (operand->parent() != instruction->parent()) {
return InternalError(
"Operand %d (%s) of instruction %s is in a different "
"computation: %s vs %s",
i, operand->name(), instruction->name(),
operand->parent() ? operand->parent()->name() : "(null)",
instruction->parent()->name());
}
}
}
}
return Status::OK();
}
namespace {
// Returns true if the given Shape has a TOKEN shape as any subshape.
bool ShapeContainsToken(const Shape& shape) {
bool contains_token = false;
ShapeUtil::ForEachSubshape(
shape, [&contains_token](const Shape& subshape, const ShapeIndex&) {
if (subshape.IsToken()) {
contains_token = true;
}
});
return contains_token;
}
// Verifies that all types entering and exiting the entry computation are
// legal.
Status VerifyEntryAndExitShapes(const HloModule& module) {
// Tokens cannot be passed as entry parameters.
// TODO(b/80000000): Remove this constraint.
for (int i = 0; i < module.entry_computation()->num_parameters(); ++i) {
HloInstruction* param =
module.entry_computation()->parameter_instruction(i);
if (ShapeContainsToken(param->shape())) {
return InternalError(
"Entry parameter %d is or contains a token shape: %s", i,
ShapeUtil::HumanString(param->shape()));
}
}
return Status::OK();
}
// Checks if the given two instructions share the same channel id.
Status CheckSameChannel(const HloInstruction* instr1,
const HloInstruction* instr2) {
if (instr1->channel_id() != instr2->channel_id()) {
return InternalError(
"Expected to have the same channel id, actual channel ids are: %s "
"(%d), %s (%d)",
instr1->ToString(), *instr1->channel_id(), instr2->ToString(),
*instr2->channel_id());
}
return Status::OK();
}
// Checks if the given two instructions have the same is_host_transfer
attribute value. Instructions must be send/recv instructions or their
// 'done' variant.
Status CheckSameIsHostTransfer(const HloInstruction* instr1,
const HloInstruction* instr2) {
const HloSendRecvInstruction* send_recv1 =
DynCast<const HloSendRecvInstruction>(instr1);
const HloSendRecvInstruction* send_recv2 =
DynCast<const HloSendRecvInstruction>(instr2);
TF_RET_CHECK(send_recv1 != nullptr);
TF_RET_CHECK(send_recv2 != nullptr);
if (send_recv1->is_host_transfer() != send_recv2->is_host_transfer()) {
return InternalError(
"Expected instructions to have the same is-host-transfer property: "
"%s, "
"%s ",
instr1->ToString(), instr2->ToString());
}
return Status::OK();
}
Status VerifySingleUser(const HloInstruction* instruction,
HloOpcode expected_user) {
TF_RET_CHECK(instruction->users().size() == 1)
<< "The " << HloOpcodeString(instruction->opcode())
<< " instruction requires one consumer, found "
<< instruction->users().size();
const HloInstruction* user = instruction->users().front();
TF_RET_CHECK(user->opcode() == expected_user)
<< "The consumer of a " << HloOpcodeString(instruction->opcode())
<< " instruction needs to be " << HloOpcodeString(expected_user)
<< ", found " << HloOpcodeString(user->opcode());
return Status::OK();
}
Status VerifySingleOperand(const HloInstruction* instruction,
HloOpcode expected_operand) {
TF_RET_CHECK(instruction->operands().size() == 1)
<< "The " << HloOpcodeString(instruction->opcode())
<< " instruction requires one consumer, found "
<< instruction->users().size();
const HloInstruction* operand = instruction->operand(0);
TF_RET_CHECK(operand->opcode() == expected_operand)
<< "The operand of a " << HloOpcodeString(instruction->opcode())
<< " instruction needs to be " << HloOpcodeString(expected_operand)
<< ", found " << HloOpcodeString(operand->opcode());
return Status::OK();
}
// Checks asynchronous instruction pairs.
Status VerifyAsynchronousInstructionPairs(const HloModule& module) {
// CopyStart must have a single CopyDone user.
for (const HloComputation* computation : module.computations()) {
for (const HloInstruction* instruction : computation->instructions()) {
switch (instruction->opcode()) {
case HloOpcode::kCopyStart: {
TF_RETURN_IF_ERROR(
VerifySingleUser(instruction, HloOpcode::kCopyDone));
break;
}
case HloOpcode::kCopyDone: {
TF_RETURN_IF_ERROR(
VerifySingleOperand(instruction, HloOpcode::kCopyStart));
break;
}
case HloOpcode::kCollectivePermuteStart: {
TF_RETURN_IF_ERROR(
VerifySingleUser(instruction, HloOpcode::kCollectivePermuteDone));
break;
}
case HloOpcode::kCollectivePermuteDone: {
TF_RETURN_IF_ERROR(VerifySingleOperand(
instruction, HloOpcode::kCollectivePermuteStart));
break;
}
default:
break;
}
}
}
return Status::OK();
}
// Checks that AllReduce instructions in the module are either all layout
// constrained or all unconstrained.
Status VerifyLayoutConstrainedAllReduce(const HloModule& module) {
const HloAllReduceInstruction* reference = nullptr;
for (const HloComputation* computation : module.computations()) {
for (const HloInstruction* instruction : computation->instructions()) {
if (instruction->opcode() != HloOpcode::kAllReduce) {
continue;
}
auto all_reduce = DynCast<HloAllReduceInstruction>(instruction);
if (!reference) {
reference = all_reduce;
}
if (reference->constrain_layout() != all_reduce->constrain_layout()) {
return FailedPrecondition(
"HloModule has a mix of layout constrained and unconstrained "
"AllReduce instructions.");
}
}
}
return Status::OK();
}
// Checks various invariants of channel instructions (send/recv and
// collectives).
Status VerifyChannels(const HloModule& module) {
absl::flat_hash_map<int64, std::vector<const HloInstruction*>>
channel_instructions;
// Send/Recv instructions must have a single user: the corresponding
// SendDone/RecvDone with a matching channel.
for (const HloComputation* computation : module.computations()) {
for (const HloInstruction* instruction : computation->instructions()) {
auto channel_instr = DynCast<HloChannelInstruction>(instruction);
if (!channel_instr || !channel_instr->channel_id()) {
continue;
}
channel_instructions[*channel_instr->channel_id()].push_back(instruction);
switch (instruction->opcode()) {
case HloOpcode::kSend: {
TF_RET_CHECK(instruction->users().size() == 1);
const HloInstruction* send_done = instruction->users().front();
TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone);
TF_RETURN_IF_ERROR(CheckSameChannel(instruction, send_done));
TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, send_done));
break;
}
case HloOpcode::kRecv: {
TF_RET_CHECK(instruction->users().size() == 1);
const HloInstruction* recv_done = instruction->users().front();
TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
TF_RETURN_IF_ERROR(CheckSameChannel(instruction, recv_done));
TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, recv_done));
break;
}
case HloOpcode::kSendDone:
TF_RET_CHECK(instruction->operands().size() == 1);
TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kSend);
break;
case HloOpcode::kRecvDone:
TF_RET_CHECK(instruction->operands().size() == 1);
TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kRecv);
break;
default:
break;
}
}
}
// Iterate over each channel to check invariants.
for (auto& pair : channel_instructions) {
auto& instructions = pair.second;
const HloInstruction* first = instructions[0];
auto sendrecv = DynCast<HloSendRecvInstruction>(first);
if (sendrecv) {
absl::flat_hash_set<HloOpcode> opcodes;
for (const HloInstruction* instr : instructions) {
opcodes.insert(instr->opcode());
auto cast = DynCast<HloSendRecvInstruction>(instr);
TF_RET_CHECK(cast != nullptr)
<< "channel " << pair.first
<< " is used for different types of channel instructions";
}
if (sendrecv->is_host_transfer()) {
TF_RET_CHECK(instructions.size() == 2)
<< "channel " << pair.first
<< " is used for multiple host send/recv instructions";
} else {
TF_RET_CHECK(instructions.size() == opcodes.size())
<< "channel " << pair.first
<< " is used for multiple send/recv instructions";
}
} else {
for (const HloInstruction* instr : instructions) {
TF_RET_CHECK(first->opcode() == instr->opcode())
<< "channel " << pair.first
<< " is used for different types of channel instructions";
}
}
}
return Status::OK();
}
// CHECKs various invariants of a fusion instruction.
Status CheckFusionInstruction(HloInstruction* fusion) {
// The parent fusion instruction of the fusion computation must be 'fusion'.
HloComputation* fused_computation = fusion->fused_instructions_computation();
if (fusion != fused_computation->FusionInstruction()) {
return InternalError(
"Instruction of fused computation does not match expected "
"instruction "
"%s.",
fusion->ToString());
}
// Fused root instruction and fused parameters must all be owned by the
// fusion computation.
bool root_owned = false;
const std::vector<HloInstruction*>& fused_parameters =
fusion->fused_parameters();
const HloInstruction* fused_root = fusion->fused_expression_root();
std::vector<bool> parameter_owned(fused_parameters.size(), false);
for (auto* instruction : fused_computation->instructions()) {
if (fused_root == instruction) {
if (root_owned) {
return InternalError("Root appears more than once in %s.",
fusion->ToString());
}
root_owned = true;
}
for (int i = 0; i < fused_parameters.size(); ++i) {
if (fused_parameters[i] == instruction) {
if (parameter_owned[i]) {
return InternalError("Parameter appears more than once in %s.",
fusion->ToString());
}
parameter_owned[i] = true;
}
}
}
if (!root_owned) {
return InternalError("Root not found in computation of %s.",
fusion->ToString());
}
// Make sure all the parameter_owned entries are set
for (int i = 0; i < parameter_owned.size(); i++) {
if (!parameter_owned[i]) {
return InternalError("Parameter %d not found in computation of %s.", i,
fusion->ToString());
}
}
// Fused root must have no users.
if (fused_root->user_count() != 0) {
return InternalError("Root of %s may not have users.", fusion->ToString());
}
// All uses of fused instructions must be in the fusion computation, and
// every non-root instruction must have at least one use.
for (auto* instruction :
fusion->fused_instructions_computation()->instructions()) {
if (instruction != fused_root) {
if (instruction->user_count() == 0) {
return InternalError("Non-root instruction %s in %s must have users.",
instruction->ToString(), fusion->ToString());
}
for (auto& user : instruction->users()) {
if (fused_computation != user->parent()) {
return InternalError(
"Non-root instruction %s in %s may not have external users.",
instruction->ToString(), fusion->ToString());
}
}
}
}
// Fused parameter instructions must be numbered contiguously and match up
// (shapes equal) with their respective operand.
CHECK_EQ(fusion->operands().size(), fused_parameters.size());
std::vector<bool> parameter_numbers(fused_parameters.size(), false);
for (auto fused_param : fused_parameters) {
int64 param_no = fused_param->parameter_number();
if (param_no < 0) {
return InternalError("Unexpected negative parameter number %d in %s.",
param_no, fusion->ToString());
}
if (param_no >= fused_parameters.size()) {
return InternalError(
"Unexpected parameter number %d in %s: higher then number of "
"parameters %lu.",
param_no, fusion->ToString(), fused_parameters.size());
}
if (parameter_numbers[param_no]) {
return InternalError(
"Did not expect parameter number %d more than once in %s.", param_no,
fusion->ToString());
}
parameter_numbers[param_no] = true;
}
// Make sure all the parameter_numbers entries were seen.
for (int i = 0; i < parameter_numbers.size(); i++) {
if (!parameter_numbers[i]) {
return InternalError("Did not see parameter number %d in %s.", i,
fusion->ToString());
}
}
TF_RET_CHECK(fusion->called_computations() ==
absl::Span<HloComputation* const>(
{fusion->fused_instructions_computation()}))
<< "Fusion HLO calls computations other than the "
"fused_instructions_computation: "
<< fusion->ToString() << " fusion->fused_instructions_computation(): "
<< fusion->fused_instructions_computation()->ToString()
<< " fusion->called_computations(): "
<< ComputationsToString(fusion->called_computations());
for (const auto& fused : fusion->fused_instructions()) {
TF_RET_CHECK(fused->parent() == fusion->fused_instructions_computation())
<< "Fused HLO was missing a parent: " << fused->ToString()
<< " parent: " << fused->parent()
<< " computation: " << fusion->parent();
}
// TODO(b/65423525): We'd like to check that all operands are distinct.
// This is currently disabled due to the invariant being violated by
// multi-output fusion.
return Status::OK();
}
// Checks that the operand shapes are compatible with the output shape, i.e.,
// that there are no implicit broadcasts.
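// For example (hypothetical HLO), add(f32[2,3] %x, f32[3] %y) would be
// rejected here; the f32[3] operand must first be explicitly broadcast to
// f32[2,3].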
Status CheckElementwiseInstruction(HloInstruction* instruction) {
const Shape& out_shape = instruction->shape();
for (HloInstruction* operand : instruction->operands()) {
const Shape& operand_shape = operand->shape();
if (!ShapeUtil::CompatibleIgnoringElementType(operand_shape, out_shape)) {
return FailedPrecondition(
"Implicit broadcast is not allowed in HLO."
"Found different shapes for instruction %s.\n"
"output: %s\noperand: %s\n",
HloOpcodeString(instruction->opcode()),
ShapeUtil::HumanString(out_shape),
ShapeUtil::HumanString(operand_shape));
}
}
return Status::OK();
}
// Visitor which verifies various fields on the HLO instruction. This class
// does not check result shapes, as those are checked by the ShapeVerifier.
class InstructionVerifier : public DfsHloVisitorWithDefault {
public:
explicit InstructionVerifier(std::function<bool(const HloInstruction*)>
instruction_can_change_layout_func)
: instruction_can_change_layout_func_(
instruction_can_change_layout_func) {}
Status DefaultAction(HloInstruction*) override { return Status::OK(); }
Status HandleFusion(HloInstruction* fusion) override {
return CheckFusionInstruction(fusion);
}
Status HandleBroadcast(HloInstruction* broadcast) override {
// If you see this failure then someone has confused the HLO broadcast op
// with the UserComputation broadcast op. See
// https://groups.google.com/forum/#!topic/xla-dev/9LqijHmTt_I
// or ComputationLowerer::Visit().
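// A valid example (hypothetical): %b = f32[2,3] broadcast(f32[3] %x),
// dimensions={1} -- one entry in dimensions per operand dimension, naming
// the output dimension it maps to.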
TF_RET_CHECK(broadcast->dimensions().size() ==
broadcast->operand(0)->shape().rank())
<< "Broadcast HLO (" << broadcast->ToShortString()
<< ") has invalid number of dimensions: "
<< broadcast->dimensions().size()
<< " != " << broadcast->operand(0)->shape().rank();
return Status::OK();
}
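// A while instruction carries a single operand of some type T; its condition
// computation must map T -> pred[] and its body must map T -> T. The checks
// below enforce the parameter and operand counts; result shapes are the
// ShapeVerifier's job.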
Status HandleWhile(HloInstruction* xla_while) override {
auto* while_cond = xla_while->while_condition();
auto* while_body = xla_while->while_body();
if (while_cond->num_parameters() != 1) {
return FailedPrecondition(
"While condition must have exactly 1 parameter; had %d : %s",
while_cond->num_parameters(), while_cond->ToString());
}
if (while_body->num_parameters() != 1) {
return FailedPrecondition(
"While body must have exactly 1 parameter; had %d : %s",
while_body->num_parameters(), while_body->ToString());
}
if (xla_while->operand_count() != 1) {
return FailedPrecondition(
"While loop must have exactly one operand; had %d : %s",
xla_while->operand_count(), xla_while->ToString());
}
return Status::OK();
}
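// A conditional's operand 0 selects the branch (a pred or s32 scalar), and
// operand b+1 is passed as the sole argument to branch computation b, so
// every branch must take exactly one parameter.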
Status HandleConditional(HloInstruction* conditional) override {
for (int b = 0; b < conditional->branch_count(); ++b) {
if (conditional->branch_computation(b)->num_parameters() != 1) {
return FailedPrecondition(
"Branch computation %s of %s must have 1 parameter instead of %d",
conditional->branch_computation(b)->name(), conditional->ToString(),
conditional->branch_computation(b)->num_parameters());
}
}
return Status::OK();
}
Status HandleElementwiseUnary(HloInstruction* instruction) override {
return CheckElementwiseInstruction(instruction);
}
Status HandleElementwiseBinary(HloInstruction* instruction) override {
return CheckElementwiseInstruction(instruction);
}
Status HandleGetTupleElement(HloInstruction* gte) override {
TF_RET_CHECK(gte->operand(0)->shape().IsTuple());
return Status::OK();
}
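// For a transpose, output dimension i has the size of operand dimension
// dimensions[i]; e.g. (hypothetical) %t = f32[3,2] transpose(f32[2,3] %x),
// dimensions={1,0}.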
Status HandleTranspose(HloInstruction* transpose) override {
const Shape& shape = transpose->shape();
const HloInstruction* operand = transpose->operand(0);
TF_RET_CHECK(shape.dimensions().size() == transpose->dimensions().size());
TF_RET_CHECK(shape.dimensions().size() ==
operand->shape().dimensions().size());
TF_RET_CHECK(std::equal(
operand->shape().dimensions().begin(),
operand->shape().dimensions().end(),
Permute(transpose->dimensions(), shape.dimensions()).begin()))
<< "shape: " << shape << ", operand->shape(): " << shape
<< ", dimensions: {" << absl::StrJoin(transpose->dimensions(), ", ")
<< "}";
return Status::OK();
}
Status HandleAllReduce(HloInstruction* crs) override {
if (crs->channel_id().has_value()) {
TF_RET_CHECK(crs->channel_id().value() > 0)
<< "All reduce channel id must be greater than 0 for "
<< crs->ToShortString();
}
return Status::OK();
}
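// Runs before each instruction's Handle* method; enforces that instruction
// names are unique across the whole module (the same visitor instance, and
// thus instructions_by_name_, is reused for every computation).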
Status Preprocess(HloInstruction* instruction) override {
auto previous = instructions_by_name_.find(instruction->name());
TF_RET_CHECK(previous == instructions_by_name_.end())
<< "HLO has name that is not unique within module:\n"
<< instruction->ToString()
<< " in computation: " << instruction->parent()->name()
<< "\nPrevious HLO with same name:\n"
<< previous->second->ToString()
<< " in computation: " << previous->second->parent()->name();
instructions_by_name_[instruction->name()] = instruction;
return Status::OK();
}
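// Runs after each instruction's Handle* method; checks that an instruction
// which is not allowed to change layouts uses the same layout for its result
// and for every dense-array operand of the same rank.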
Status Postprocess(HloInstruction* instruction) override {
if (instruction_can_change_layout_func_ &&
LayoutUtil::IsDenseArray(instruction->shape()) &&
!instruction_can_change_layout_func_(instruction)) {
const Shape& result_shape = instruction->shape();
const Layout& result_layout = result_shape.layout();
for (HloInstruction* operand : instruction->operands()) {
const Shape& operand_shape = operand->shape();
if (LayoutUtil::IsDenseArray(operand_shape) &&
operand_shape.rank() == result_shape.rank()) {
const Layout& operand_layout = operand_shape.layout();
TF_RET_CHECK(LayoutUtil::Equal(result_layout, operand_layout))
<< "Instruction shouldn't change layouts "
<< instruction->ToString() << " From " << result_shape << " To "
<< operand_shape;
}
}
}
return Status::OK();
}
private:
absl::flat_hash_map<string, const HloInstruction*> instructions_by_name_;
// Determines whether an instruction can change layouts.
std::function<bool(const HloInstruction*)>
instruction_can_change_layout_func_;
};
} // namespace
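// Top-level driver: module-structure and channel checks first, then
// per-computation shape and instruction verification, then module-wide
// checks (entry layout, schedule, aliasing, dynamic parameter bindings, and
// layout-constrained all-reduces).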
StatusOr<bool> HloVerifier::Run(HloModule* module) {
TF_RET_CHECK(!module->name().empty());
if (module->entry_computation()->IsFusionComputation()) {
return InvalidArgument(
"Module entry computation cannot be a fusion computation");
}
TF_RETURN_IF_ERROR(VerifyHloStructure(module));
TF_RETURN_IF_ERROR(VerifyAsynchronousInstructionPairs(*module));
TF_RETURN_IF_ERROR(VerifyChannels(*module));
std::unique_ptr<ShapeVerifier> shape_verifier =
target_metadata_->GetVerifier();
InstructionVerifier instruction_verifier(instruction_can_change_layout_func_);
for (auto* computation : module->computations()) {
TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get()));
TF_RETURN_IF_ERROR(computation->Accept(&instruction_verifier));
}
TF_RETURN_IF_ERROR(shape_verifier->VerifyEntryComputationLayout(*module));
TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module));
// If the module has a schedule, it must be valid.
if (module->has_schedule()) {
TF_RETURN_IF_ERROR(module->schedule().Verify());
}
TF_RETURN_IF_ERROR(module->input_output_alias_config().Verify(
*module, [this](const Shape& shape) -> int64 {
if (target_metadata_->IsLayoutSensitive()) {
return target_metadata_->ShapeSize(shape);
} else {
return 0;
}
}));
TF_RETURN_IF_ERROR(module->dynamic_parameter_binding().Verify(*module));
TF_RETURN_IF_ERROR(VerifyLayoutConstrainedAllReduce(*module));
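// The verifier never mutates the module, so report "no change".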
return false;
}
} // namespace xla