// blob: c00b39b1645160f024ea7095f30dcdfd82f94c4a [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include <memory>
#include <utility>
#include "absl/strings/str_replace.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/layout_assignment.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace {
using ::testing::HasSubstr;
// Builds an empty HloModule named "module" with a default-constructed config.
// Nothing here runs the verifier on the module; tests do that explicitly.
std::unique_ptr<HloModule> CreateUnverifiedModule() {
  auto unverified = std::make_unique<HloModule>("module", HloModuleConfig());
  return unverified;
}
// Default fixture for the verifier tests: the HloTestBase verifier is
// configured layout-insensitive and with mixed precision disallowed. Tests
// using it build or parse (possibly malformed) HLO and run verifier() on it
// explicitly.
// NOTE(review): the previous comment said this class "cannot be converted to
// use HloTestBase" although it derives from HloTestBase; it presumably
// referred to a different (auto-verifying) test base -- confirm.
class HloVerifierTest : public HloTestBase {
 public:
  HloVerifierTest()
      : HloTestBase(
            /*verifier_layout_sensitive=*/false,
            /*allow_mixed_precision_in_hlo_verifier=*/false) {}
};
// Fixture whose verifier is layout-insensitive but allows mixed precision.
class HloVerifierTestAllowMixedPrecision : public HloTestBase {
 public:
  HloVerifierTestAllowMixedPrecision()
      : HloTestBase(
            /*verifier_layout_sensitive=*/false,
            /*allow_mixed_precision_in_hlo_verifier=*/true) {}
};
// Fixture whose verifier is layout-sensitive, disallows mixed precision, and
// uses LayoutAssignment::InstructionCanChangeLayout as the third HloTestBase
// constructor argument.
class HloVerifierTestLayoutSensitive : public HloTestBase {
 public:
  HloVerifierTestLayoutSensitive()
      : HloTestBase(
            /*verifier_layout_sensitive=*/true,
            /*allow_mixed_precision_in_hlo_verifier=*/false,
            LayoutAssignment::InstructionCanChangeLayout) {}
};
// Fixture whose verifier is layout-sensitive and disallows mixed precision.
// Unlike HloVerifierTestLayoutSensitive, no third constructor argument is
// passed, so HloTestBase's default for it applies.
class HloVerifierTestLayoutFusion : public HloTestBase {
 public:
  HloVerifierTestLayoutFusion()
      : HloTestBase(
            /*verifier_layout_sensitive=*/true,
            /*allow_mixed_precision_in_hlo_verifier=*/false) {}
};
// A well-formed module verifies; after an instruction's parent pointer is
// cleared, the verifier must report the null parent.
TEST_F(HloVerifierTest, NullInstructionParent) {
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* input = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "param"));
  HloInstruction* negated = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_f32, HloOpcode::kNegate, input));

  auto hlo_module = CreateUnverifiedModule();
  hlo_module->AddEntryComputation(entry_builder.Build());

  // Baseline: the untouched module is valid.
  TF_ASSERT_OK(verifier().Run(hlo_module.get()).status());

  // Corrupt the instruction and verify again.
  negated->set_parent(nullptr);
  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("has a null parent pointer"));
}
// A well-formed module verifies; after the entry computation's parent pointer
// is cleared, the verifier must report the null parent.
TEST_F(HloVerifierTest, NullComputationParent) {
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* input = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "param"));
  entry_builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_f32, HloOpcode::kNegate, input));

  auto hlo_module = CreateUnverifiedModule();
  HloComputation* entry =
      hlo_module->AddEntryComputation(entry_builder.Build());

  // Baseline: the untouched module is valid.
  TF_ASSERT_OK(verifier().Run(hlo_module.get()).status());

  // Corrupt the computation and verify again.
  entry->set_parent(nullptr);
  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("has a null parent pointer"));
}
// An instruction whose operand lives in another computation must be rejected.
TEST_F(HloVerifierTest, DifferentOperandParents) {
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});

  // Entry computation: negate(param).
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* entry_param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "param"));
  HloInstruction* negated = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_f32, HloOpcode::kNegate,
                                  entry_param));
  auto hlo_module = CreateUnverifiedModule();
  hlo_module->AddEntryComputation(entry_builder.Build());

  // Embedded computation holding only a parameter.
  HloComputation::Builder embedded_builder(TestName());
  HloInstruction* embedded_param = embedded_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "param"));
  hlo_module->AddEmbeddedComputation(embedded_builder.Build());

  // Baseline: both computations are valid on their own.
  TF_ASSERT_OK(verifier().Run(hlo_module.get()).status());

  // Point the negate at the parameter of the *other* computation.
  TF_ASSERT_OK(negated->ReplaceOperandWith(0, embedded_param));
  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("is in a different computation"));
}
// Running the verifier twice on the same invalid module must fail both times,
// i.e. no visitor state may leak from the first run into the second.
TEST_F(HloVerifierTest, ResetsShapeVerifierState) {
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {1});
  const Shape wrong_shape = ShapeUtil::MakeShape(F32, {2});

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* input = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, operand_shape, "param"));
  // The add is given a result shape that does not match its operands.
  HloInstruction* bad_add = entry_builder.AddInstruction(
      HloInstruction::CreateBinary(wrong_shape, HloOpcode::kAdd, input,
                                   input));
  // The badly-shaped instruction must not be the computation root for the
  // bug under test to trigger, so put a multiply on top of it.
  entry_builder.AddInstruction(HloInstruction::CreateBinary(
      wrong_shape, HloOpcode::kMultiply, bad_add, bad_add));

  auto hlo_module = CreateUnverifiedModule();
  hlo_module->AddEntryComputation(entry_builder.Build());

  // Two consecutive runs; neither may succeed.
  const bool first_run_ok = verifier().Run(hlo_module.get()).status().ok();
  EXPECT_FALSE(first_run_ok);
  const bool second_run_ok = verifier().Run(hlo_module.get()).status().ok();
  EXPECT_FALSE(second_run_ok);
}
// A call whose operand tuple shape differs from the callee's parameter shape
// must be rejected.
TEST_F(HloVerifierTest, CheckCallOperandParameterShapesMismatch) {
  const char* const kHloText = R"(
HloModule Module
callme {
ROOT param = (s32[], f32[4]) parameter(0)
}
ENTRY entry {
p0 = (f32[4], s32[]) parameter(0)
ROOT mycall = (s32[], f32[4]) call(p0), to_apply=callme
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("shape does not match parameter"));
}
// A predicated conditional whose branch operands do not match the branch
// computations' parameter shapes must be rejected.
TEST_F(HloVerifierTest, CheckConditionalOperandParameterShapesMismatch) {
  const char* const kHloText = R"(
HloModule Module
true_branch {
tparam = (s32[], f32[4]) parameter(0)
ROOT tgte1 = f32[4] get-tuple-element(tparam), index=1
}
false_branch {
fparam = (s32[], f32[4]) parameter(0)
ROOT fgte1 = f32[4] get-tuple-element(fparam), index=1
}
ENTRY entry {
p0 = (f32[4], s32[]) parameter(0)
constant = pred[] constant(true)
ROOT conditional = f32[4] conditional(constant, p0, p0),
true_computation=true_branch, false_computation=false_branch
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("shape does not match parameter"));
}
// Checks the verifier's constraints on the branch-index operand of an indexed
// (N-way) conditional:
//   1. the module as written (s32[] index) verifies;
//   2. changing the index to f32[] is rejected as not being an S32 scalar;
//   3. changing the index to s32[4] is rejected as not being a scalar.
TEST_F(HloVerifierTest, CheckConditionalBranchIndexOperandShape) {
  const char* const hlo_string = R"(
HloModule Module
branch0 {
tparam = f32[4] parameter(0)
ROOT tgte1 = f32[4] ceil(tparam)
}
branch1 {
fparam = f32[4] parameter(0)
ROOT fgte1 = f32[4] floor(fparam)
}
branch2 {
sparam = f32[4] parameter(0)
ROOT sgte1 = f32[4] ceil(sparam)
}
ENTRY entry {
p0 = f32[4] parameter(0)
b0 = s32[] parameter(1)
ROOT conditional = f32[4] conditional(b0, p0, p0, p0),
branch_computations={branch0, branch1, branch2}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(hlo_string));

  // Baseline: the unmodified module must verify. Previously this result was
  // computed and then silently overwritten, so the later failures were never
  // shown to come from the shape mutations alone.
  TF_ASSERT_OK(verifier().Run(module.get()).status());

  HloInstruction* condition = FindInstruction(module.get(), "b0");
  // Guard the dereferences below against a failed lookup.
  ASSERT_NE(condition, nullptr);

  // Wrong element type: scalar, but F32 instead of S32.
  *condition->mutable_shape() = ShapeUtil::MakeShape(F32, {});
  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(
      status.error_message(),
      HasSubstr(
          "first operand of indexed conditional must be a scalar of S32"));

  // Wrong rank: S32, but not a scalar.
  *condition->mutable_shape() = ShapeUtil::MakeShape(S32, {4});
  status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("first operand of conditional must be a scalar"));
}
// An rng instruction with a non-scalar operand (here f16[2]) must be rejected.
TEST_F(HloVerifierTest, RngOpnd0NotScalar) {
  const char* const kHloText = R"(
HloModule Module
ENTRY RngOpnd0NotScalar {
constant.0 = f32[] constant(0)
constant.1 = f16[2] constant({1, 3})
ROOT rng.0 = f32[10]{0} rng(f32[] constant.0, f16[2] constant.1),
distribution=rng_uniform
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected scalar type"));
}
// An rng instruction whose two operands have different element types
// (f32 vs f16) must be rejected.
TEST_F(HloVerifierTest, RngOperandElementTypesDoNotMatch) {
  const char* const kHloText = R"(
HloModule Module
ENTRY RngOperandElementTypesNotMatch {
constant.0 = f32[] constant(0)
constant.1 = f16[] constant(1)
ROOT rng.0 = f32[10]{0} rng(f32[] constant.0, f16[] constant.1),
distribution=rng_normal
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected compatible element types"));
}
// With mixed precision disallowed, an rng producing f16 from f32 operands
// must be rejected.
TEST_F(HloVerifierTest, RngMixedPrecisionNotAllowed) {
  const char* const kHloText = R"(
HloModule Module
ENTRY RngResultElementTypeNotMatch {
constant.0 = f32[] constant(0)
constant.1 = f32[] constant(1)
ROOT rng.0 = f16[10]{0} rng(f32[] constant.0, f32[] constant.1),
distribution=rng_normal
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected compatible element types"));
}
// With mixed precision allowed, the same f16-from-f32 rng module parses as a
// verified module and passes an explicit verifier run.
TEST_F(HloVerifierTestAllowMixedPrecision, RngMixedPrecisionAllowed) {
  const char* const kHloText = R"(
HloModule Module
ENTRY RngResultElementTypeNotMatch {
constant.0 = f32[] constant(0)
constant.1 = f32[] constant(1)
ROOT rng.0 = f16[10]{0} rng(f32[] constant.0, f32[] constant.1),
distribution=rng_normal
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_TRUE(verify_status.ok());
}
// An rng_normal over s32 operands/result must be rejected as an unsupported
// element type.
TEST_F(HloVerifierTest, RngElementTypeNotSupported) {
  const char* const kHloText = R"(
HloModule Module
ENTRY RngElementTypeNotSupported {
constant.0 = s32[] constant(0)
constant.1 = s32[] constant(1)
ROOT rng.0 = s32[10]{0} rng(s32[] constant.0, s32[] constant.1),
distribution=rng_normal
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Element type not supported"));
}
// A pad with interior padding of -1 must be rejected.
TEST_F(HloVerifierTest, NegativeInteriorPaddingNotAllowed) {
  // Built through the C++ API rather than textual HLO: the text parser does
  // not accept negative interior padding (which is probably a feature).
  const Shape vector_shape = ShapeUtil::MakeShape(F32, {100});

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* input = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape, "param"));

  PaddingConfig bad_padding;
  bad_padding.add_dimensions()->set_interior_padding(-1);

  HloInstruction* pad_value = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(F32)));
  entry_builder.AddInstruction(
      HloInstruction::CreatePad(vector_shape, input, pad_value, bad_padding));

  auto hlo_module = CreateUnverifiedModule();
  hlo_module->AddEntryComputation(entry_builder.Build());

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Interior padding cannot be negative"));
}
// A pad with interior padding of -1 must produce the "Interior padding cannot
// be negative" error. Same scenario as NegativeInteriorPaddingNotAllowed, but
// the pad value constant is created from a cloned literal and only the error
// message is checked.
TEST_F(HloVerifierTest, PadNegativeInteriorDilationNotAllowed) {
  // Built through the C++ API rather than textual HLO: the text parser does
  // not accept negative interior padding (which is probably a feature).
  const Shape vector_shape = ShapeUtil::MakeShape(F32, {100});

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* input = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape, "param"));

  PaddingConfig bad_padding;
  bad_padding.add_dimensions()->set_interior_padding(-1);

  HloInstruction* pad_value = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(F32).Clone()));
  entry_builder.AddInstruction(
      HloInstruction::CreatePad(vector_shape, input, pad_value, bad_padding));

  auto hlo_module = CreateUnverifiedModule();
  hlo_module->AddEntryComputation(entry_builder.Build());

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Interior padding cannot be negative"));
}
// A dot with an f32 lhs and a bf16 rhs verifies even in the fixture that
// disallows mixed precision.
TEST_F(HloVerifierTest, DotMixedPrecisionAllowed) {
  static const char* const kDotHloText = R"(
HloModule module
ENTRY entry_computation {
a = f32[2,10] parameter(0)
b = bf16[10,2] parameter(1)
ROOT dot = f32[2,2] dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kDotHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  EXPECT_TRUE(verify_status.ok()) << verify_status;
}
// Simple module containing a convolution as the root. Shared by the
// ConvNegativeWindowDilationNotAllowed and ConvNegativeBaseDilationNotAllowed
// tests, which parse it and then mutate the root convolution's window before
// running the verifier. Note that zero_f16 is not referenced by any other
// instruction in this text.
static const char* const kConvHloString = R"(
HloModule module
ENTRY entry_computation {
param0 = f16[128,128,56,56] parameter(0)
param1 = f16[3,3,128,128] parameter(1)
zero_f16 = f16[] constant(0)
ROOT conv = f16[128,128,28,28] convolution(param0, param1),
window={size=3x3 stride=2x2}, dim_labels=bf01_01io->bf01
})";
// A convolution whose first window dimension has window dilation -1 must be
// rejected.
TEST_F(HloVerifierTest, ConvNegativeWindowDilationNotAllowed) {
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kConvHloString));

  // Copy the root convolution's window, corrupt it, and write it back.
  HloInstruction* conv_root =
      hlo_module->entry_computation()->root_instruction();
  Window bad_window = conv_root->window();
  bad_window.mutable_dimensions(0)->set_window_dilation(-1);
  conv_root->set_window(bad_window);

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("non-positive window dilation factor"));
}
// A convolution whose first window dimension has base dilation -1 must be
// rejected.
TEST_F(HloVerifierTest, ConvNegativeBaseDilationNotAllowed) {
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kConvHloString));

  // Copy the root convolution's window, corrupt it, and write it back.
  HloInstruction* conv_root =
      hlo_module->entry_computation()->root_instruction();
  Window bad_window = conv_root->window();
  bad_window.mutable_dimensions(0)->set_base_dilation(-1);
  conv_root->set_window(bad_window);

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("non-positive base area dilation factor"));
}
// Module whose root add takes operands with different layouts ({1,0} and
// {0,1}) and produces a {1,0} result. HloVerifierTest.AddWithLayoutChange
// expects it to verify; HloVerifierTestLayoutSensitive
// .AddWithLayoutChangeNotAllowed expects it to be rejected.
static const char* const kAddWithLayoutChangeHlo = R"(
HloModule AddWithLayoutChange
ENTRY AddWithLayoutChange {
par0 = f32[3,4]{1,0} parameter(0)
par1 = f32[3,4]{0,1} parameter(1)
ROOT add0 = f32[3,4]{1,0} add(par0,par1)
}
)";
// With a layout-insensitive verifier, an add whose operands have different
// layouts is accepted.
TEST_F(HloVerifierTest, AddWithLayoutChange) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module, ParseAndReturnVerifiedModule(kAddWithLayoutChangeHlo));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_TRUE(verify_status.ok());
}
// A dynamic-slice taking one scalar index operand per dimension verifies when
// xla_allow_scalar_index_dynamic_ops is enabled in the module's debug options.
TEST_F(HloVerifierTest, ScalarIndexDynamicSlice) {
  const char* const kHloText = R"(
HloModule DynamicSlice_module
ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[]) -> s32[2,2,258] {
%original_parameter = s32[2,2,258] parameter(0)
%constant = s32[] constant(0)
%start_index = s32[] parameter(1)
ROOT %dynamic-slice = s32[2,2,258] dynamic-slice(s32[2,2,258] %original_parameter, s32[] %constant, s32[] %constant, s32[] %start_index), dynamic_slice_sizes={2,2,258}
}
)";

  // Start from the default config and switch on scalar-index dynamic ops.
  HloModuleConfig module_config;
  DebugOptions options = module_config.debug_options();
  options.set_xla_allow_scalar_index_dynamic_ops(true);
  module_config.set_debug_options(options);

  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module, ParseAndReturnVerifiedModule(kHloText, module_config));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_TRUE(verify_status.ok());
}
// A dynamic-update-slice taking one scalar index operand per dimension
// verifies when xla_allow_scalar_index_dynamic_ops is enabled in the module's
// debug options.
TEST_F(HloVerifierTest, ScalarIndexDynamicUpdateSlice) {
  const char* const kHloText = R"(
HloModule DynamicUpdateSlice_module
ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_index.0: s32[], start_index.1: s32[], start_index.2: s32[], start_index.3: s32[]) -> s32[1,1,25,1] {
%input = s32[1,1,25,1]{3,2,1,0} parameter(0)
%update = s32[1,1,2,1]{3,2,1,0} parameter(1)
%start_index.0 = s32[] parameter(2)
%start_index.1 = s32[] parameter(3)
%start_index.2 = s32[] parameter(4)
%start_index.3 = s32[] parameter(5)
ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[] %start_index.0, s32[] %start_index.1, s32[] %start_index.2, s32[] %start_index.3)
}
)";

  // Start from the default config and switch on scalar-index dynamic ops.
  HloModuleConfig module_config;
  DebugOptions options = module_config.debug_options();
  options.set_xla_allow_scalar_index_dynamic_ops(true);
  module_config.set_debug_options(options);

  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module, ParseAndReturnVerifiedModule(kHloText, module_config));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_TRUE(verify_status.ok());
}
// Even in the mixed-precision-allowing fixture, a dynamic-update-slice whose
// bf16 result differs from its f32 operand is rejected with a shape-equality
// error.
TEST_F(HloVerifierTestAllowMixedPrecision, DynamicUpdateSliceMixedPrecision) {
  const char* const kHloText = R"(
HloModule kDynamicUpdateSliceMixedPrecision
ENTRY %entry (parameter.0: f32[32,511,2048], parameter.1: bf16[32,511,512], parameter.2: s32[], parameter.3: s32[], parameter.4: s32[]) -> bf16[32,511,2048] {
%parameter.0 = f32[32,511,2048] parameter(0)
%parameter.1 = bf16[32,511,512] parameter(1)
%parameter.2 = s32[] parameter(2)
%parameter.3 = s32[] parameter(3)
%parameter.4 = s32[] parameter(4)
ROOT %dus = bf16[32,511,2048] dynamic-update-slice(f32[32,511,2048] %parameter.0, bf16[32,511,512] %parameter.1, s32[] %parameter.2, s32[] %parameter.3, s32[] %parameter.4)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr("Expected instruction to have shape equal to "
                "f32[32,511,2048], actual shape is bf16[32,511,2048]"));
}
// With a layout-sensitive verifier, an add whose operands have different
// layouts is rejected.
TEST_F(HloVerifierTestLayoutSensitive, AddWithLayoutChangeNotAllowed) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module,
      ParseAndReturnUnverifiedModule(kAddWithLayoutChangeHlo));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Instruction shouldn't change layouts"));
}
// With a layout-sensitive verifier, a dynamic-slice whose result layout
// ({1,0}) differs from its operand layout ({0,1}) is rejected.
TEST_F(HloVerifierTestLayoutSensitive, SliceWithLayoutChangeNotAllowed) {
  const char* const kHloText = R"(
HloModule SliceWithLayoutChange
ENTRY SliceWithLayoutChange {
par0 = f32[4,5]{0,1} parameter(0)
par1 = s32[] parameter(1)
par2 = s32[] parameter(2)
ROOT dslice0 = f32[3,4]{1,0} dynamic-slice(par0, par1, par2),
dynamic_slice_sizes={3,4}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Instruction shouldn't change layouts"));
}
// With a layout-sensitive verifier, a concatenate whose operands have
// different layouts ({0,1} and {1,0}) is rejected.
TEST_F(HloVerifierTestLayoutSensitive, ConcatWithLayoutChangeNotAllowed) {
  const char* const kHloText = R"(
HloModule ConcatWithLayoutChange
ENTRY ConcatWithLayoutChange {
par0 = f32[3,5]{0,1} parameter(0)
par1 = f32[3,3]{1,0} parameter(1)
ROOT concat0 = f32[3,8]{1,0} concatenate(f32[3,5] par0, f32[3,3] par1),
dimensions={1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Instruction shouldn't change layouts"));
}
// A bitcast from f32[2] to f32[3] is rejected; the error reports the output
// and operand sizes (12 and 8).
TEST_F(HloVerifierTestLayoutSensitive, BitcastNeedsSameNumberOfElements) {
  const char* const kHloText = R"(
HloModule Module
ENTRY BitcastNeedsToBeNoOp {
constant.0 = f32[2] constant({0.0, 0.0})
ROOT bitcast = f32[3] bitcast(constant.0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr("Bitcast cannot have different shape sizes of output "
                "(12) and operand (8)"));
}
// With mixed precision disallowed, a select over f32 and bf16 branches is
// rejected.
TEST_F(HloVerifierTest, SelectMixedPrecisionNotAllowed) {
  const char* const kHloText = R"(
HloModule Module
ENTRY SelectMixedPrecisionNotAllowed {
p0 = pred[32] parameter(0)
p1 = f32[32] parameter(1)
p2 = bf16[32] parameter(2)
ROOT select = f32[32] select(p0, p1, p2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Seen floating point types of different precisions"));
}
// With mixed precision allowed, a select over f32 and bf16 branches verifies.
TEST_F(HloVerifierTestAllowMixedPrecision, SelectMixedPrecisionAllowed) {
  const char* const kHloText = R"(
HloModule Module
ENTRY SelectMixedPrecisionAllowed {
p0 = pred[32] parameter(0)
p1 = f32[32] parameter(1)
p2 = bf16[32] parameter(2)
ROOT select = f32[32] select(p0, p1, p2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_TRUE(verify_status.ok());
}
// A select over tuple-shaped operands is rejected.
TEST_F(HloVerifierTest, SelectTupleNotAllowed) {
  const char* const kHloText = R"(
HloModule Module
ENTRY SelectWithTuple {
p0 = (f32[], f32[]) parameter(0)
p1 = (f32[], f32[]) parameter(1)
p2 = pred[] parameter(2)
ROOT select = (f32[], f32[]) select(p2, p0, p1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected array argument for select"));
}
// A well-formed copy-start / copy-done pair whose destination uses a
// different memory-space annotation (S(2) vs S(1)) verifies under the
// layout-sensitive verifier.
TEST_F(HloVerifierTestLayoutSensitive, CopyStartAndCopyDone) {
  const char* const kHloText = R"(
HloModule Module
ENTRY CopyStartAndCopyDone {
p0 = f32[2,3]{1,0:S(1)} parameter(0)
copy-start = (f32[2,3]{1,0:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) copy-start(p0)
ROOT copy-done = f32[2,3]{1,0:S(2)} copy-done(copy-start)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_TRUE(verify_status.ok());
}
// A copy-start whose first tuple element has a different minor-to-major
// order ({0,1}) than the operand ({1,0}) is rejected under the
// layout-sensitive verifier.
TEST_F(HloVerifierTestLayoutSensitive, CopyStartAndCopyDoneWrongLayout) {
  const char* const kHloText = R"(
HloModule Module
ENTRY CopyStartAndCopyDone {
p0 = f32[2,3]{1,0:S(1)} parameter(0)
copy-start = (f32[2,3]{0,1:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) copy-start(p0)
ROOT copy-done = f32[2,3]{1,0:S(2)} copy-done(copy-start)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected instruction to have shape equal to"));
}
// A copy-start declared with a plain array shape instead of the expected
// (dest, src, u32[]) tuple is rejected.
TEST_F(HloVerifierTest, CopyStartAndCopyDoneWrongType) {
  const char* const kHloText = R"(
HloModule Module
ENTRY CopyStartAndCopyDone {
p0 = f32[2,3] parameter(0)
copy-start = f32[2,3] copy-start(p0)
ROOT copy-done = f32[2,3] copy-done(copy-start)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected instruction to have shape equal to "
                        "(f32[2,3], f32[2,3], u32[])"));
}
// A copy-start consumed by two copy-done instructions is rejected.
TEST_F(HloVerifierTest, CopyStartMultipleCopyDone) {
  const char* const kHloText = R"(
HloModule Module
ENTRY CopyStartAndCopyDone {
p0 = f32[2,3] parameter(0)
copy-start = (f32[2,3], f32[2,3], u32[]) copy-start(p0)
copy-done.1 = f32[2,3] copy-done(copy-start)
copy-done.2 = f32[2,3] copy-done(copy-start)
ROOT tuple = (f32[2,3], f32[2,3]) tuple(copy-done.1, copy-done.2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr("copy-start instruction requires one consumer, found 2"));
}
// A copy-done whose operand is a tuple rather than a copy-start is rejected.
TEST_F(HloVerifierTest, CopyDoneNoCopyStart) {
  const char* const kHloText = R"(
HloModule Module
ENTRY CopyStartAndCopyDone {
p0 = f32[2,3] parameter(0)
p1 = u32[] parameter(1)
tuple = (f32[2,3], f32[2,3], u32[]) tuple(p0, p0, p1)
ROOT copy-done = f32[2,3] copy-done(tuple)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr("The operand of a copy-done instruction needs to be "
                "copy-start, found tuple"));
}
// A well-formed custom-call-start / custom-call-done pair verifies under the
// layout-sensitive verifier.
TEST_F(HloVerifierTestLayoutSensitive, AsyncStartAndAsyncDone) {
  const char* const kHloText = R"(
HloModule Module
ENTRY AsyncStartAndAsyncDone {
p0 = f32[2,3]{1,0:S(1)} parameter(0)
async-start = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-start(p0), custom_call_target="foo"
ROOT async-done = f32[2,3]{1,0:S(2)} custom-call-done(async-start), custom_call_target="foo"
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_TRUE(verify_status.ok());
}
// A start -> update -> update -> done chain of async custom-call ops verifies
// under the layout-sensitive verifier.
TEST_F(HloVerifierTestLayoutSensitive, AsyncStartAndAsyncUpdateAndAsyncDone) {
  const char* const kHloText = R"(
HloModule Module
ENTRY AsyncStartAndAsyncUpdateAndAsyncDone {
p0 = f32[2,3]{1,0:S(1)} parameter(0)
async-start = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-start(p0), custom_call_target="foo"
async-update.1 = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-update(async-start), custom_call_target="foo"
async-update.2 = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-update(async-update.1), custom_call_target="foo"
ROOT async-done = f32[2,3]{1,0:S(2)} custom-call-done(async-update.2), custom_call_target="foo"
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_TRUE(verify_status.ok());
}
// Same start -> update -> update -> done chain, with every op carrying the
// same async_thread_name ("parallel_thread"); it must verify.
TEST_F(HloVerifierTestLayoutSensitive,
       AsyncStartAndAsyncUpdateAndAsyncDoneWithThreadName) {
  const char* const kHloText = R"(
HloModule Module
ENTRY AsyncStartAndAsyncUpdateAndAsyncDone {
p0 = f32[2,3]{1,0:S(1)} parameter(0)
async-start = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-start(p0), async_thread_name="parallel_thread", custom_call_target="foo"
async-update.1 = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-update(async-start), async_thread_name="parallel_thread", custom_call_target="foo"
async-update.2 = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-update(async-update.1), async_thread_name="parallel_thread", custom_call_target="foo"
ROOT async-done = f32[2,3]{1,0:S(2)} custom-call-done(async-update.2), async_thread_name="parallel_thread", custom_call_target="foo"
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_TRUE(verify_status.ok());
}
// The async shape's element at index {1} (f32[3,2]) disagrees with the
// async-done result (f32[2,3]); the verifier must reject the module.
TEST_F(HloVerifierTest, AsyncStartAndAsyncDoneWrongType) {
  const char* const kHloText = R"(
HloModule Module
ENTRY AsyncStartAndAsyncDone {
p0 = f32[2,3] parameter(0)
async-start = ((f32[2,3]), f32[3,2], u32[]) custom-call-start(p0), custom_call_target="foo"
ROOT async-done = f32[2,3] custom-call-done(async-start), custom_call_target="foo"
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr("async-done expects the async shape at index {1} to "
                "match the async computation root shape"));
}
// async-start and async-done carrying different async_thread_name values
// ("parallel_thread" vs "main_thread") must be rejected.
TEST_F(HloVerifierTest, AsyncStartAndAsyncDoneWrongThreadName) {
  const char* const kHloText = R"(
HloModule Module
ENTRY AsyncStartAndAsyncDone {
p0 = f32[2,3] parameter(0)
async-start = ((f32[2,3]), f32[2,3], u32[]) custom-call-start(p0), async_thread_name="parallel_thread", custom_call_target="foo"
ROOT async-done = f32[2,3] custom-call-done(async-start), async_thread_name="main_thread", custom_call_target="bar"
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("thread name (main_thread vs parallel_thread)."));
}
// async-start and async-done wrapping custom calls with different targets
// ("foo" vs "bar") must be rejected.
TEST_F(HloVerifierTest, AsyncStartAndAsyncDoneWrongAttr) {
  const char* const kHloText = R"(
HloModule Module
ENTRY AsyncStartAndAsyncDone {
p0 = f32[2,3] parameter(0)
async-start = ((f32[2,3]), f32[2,3], u32[]) custom-call-start(p0), custom_call_target="foo"
ROOT async-done = f32[2,3] custom-call-done(async-start), custom_call_target="bar"
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr("async-done expects its wrapped async computation to "
                "be identical to its operand's"));
}
// An async-start consumed by two async-done instructions is rejected.
TEST_F(HloVerifierTest, AsyncStartMultipleAsyncDone) {
  const char* const kHloText = R"(
HloModule Module
ENTRY AsyncStartAndAsyncDone {
p0 = f32[2,3] parameter(0)
async-start = ((f32[2,3]), f32[2,3], u32[]) custom-call-start(p0), custom_call_target="foo"
async-done.1 = f32[2,3] custom-call-done(async-start), custom_call_target="foo"
async-done.2 = f32[2,3] custom-call-done(async-start), custom_call_target="foo"
ROOT tuple = (f32[2,3], f32[2,3]) tuple(async-done.1, async-done.2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr("async-start instruction requires one consumer, found 2"));
}
// An async-start with no consumer at all is rejected.
TEST_F(HloVerifierTest, AsyncStartNoAsyncDone) {
  const char* const kHloText = R"(
HloModule Module
ENTRY AsyncStartAndAsyncDone {
p0 = f32[2,3] parameter(0)
ROOT async-start = ((f32[2,3]), f32[2,3], u32[]) custom-call-start(p0), custom_call_target="foo"
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr("async-start instruction requires one consumer, found 0"));
}
// An async-update with no consumer (no async-done after it) is rejected.
TEST_F(HloVerifierTest, AsyncStartAndAsyncUpdateNoAsyncDone) {
  const char* const kHloText = R"(
HloModule Module
ENTRY AsyncStartAndAsyncDone {
p0 = f32[2,3] parameter(0)
async-start = ((f32[2,3]), f32[2,3], u32[]) custom-call-start(p0), custom_call_target="foo"
ROOT async-update = ((f32[2,3]), f32[2,3], u32[]) custom-call-update(async-start), custom_call_target="foo"
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr("async-update instruction requires one consumer, found 0"));
}
// An async-done whose operand is a tuple rather than an async-start or
// async-update is rejected.
TEST_F(HloVerifierTest, AsyncDoneNoAsyncStart) {
  const char* const kHloText = R"(
HloModule Module
ENTRY AsyncStartAndAsyncDone {
p0 = f32[2,3] parameter(0)
p1 = u32[] parameter(1)
tuple = ((f32[2,3]), f32[2,3], u32[]) tuple(p0, p0, p1)
ROOT async-done = f32[2,3] custom-call-done(tuple), custom_call_target="foo"
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr("The operand of a async-done instruction needs to be "
                "async-start or async-update, found tuple"));
}
// An async-update whose operand is a tuple rather than an async-start or
// async-update is rejected. Note that the async-done in this module also
// consumes the tuple directly, not the async-update.
TEST_F(HloVerifierTest, AsyncUpdateAndAsyncDoneNoAsyncStart) {
  const char* const kHloText = R"(
HloModule Module
ENTRY AsyncStartAndAsyncDone {
p0 = f32[2,3] parameter(0)
p1 = u32[] parameter(1)
tuple = ((f32[2,3]), f32[2,3], u32[]) tuple(p0, p0, p1)
async-update = ((f32[2,3]), f32[2,3], u32[]) custom-call-update(tuple), custom_call_target="foo"
ROOT async-done = f32[2,3] custom-call-done(tuple), custom_call_target="foo"
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr("The operand of a async-update instruction needs to be "
                "async-start or async-update, found tuple"));
}
// The async shape at index {0} is f32[3,2] while the async computation's
// parameter is f32[2,3]; the verifier must flag the parameter-shape mismatch.
TEST_F(HloVerifierTest, AsyncOpComputationParamWrongType) {
  const char* const kHloText = R"(
HloModule Module
async_computation {
p = f32[2,3] parameter(0)
ROOT custom-call = f32[3,2] custom-call(p), custom_call_target="foo"
}
ENTRY AsyncStartAndAsyncDone {
p0 = f32[2,3] parameter(0)
async-start = ((f32[3,2]), f32[3,2], u32[]) async-start(p0), calls=async_computation
ROOT async-done = f32[3,2] async-done(async-start), calls=async_computation
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto unverified_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(unverified_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("async-start expects the async shape at index {0} to "
                        "match async computation parameter shape"));
}
// The async shape at index {1} is f32[2,3] while the async computation's root
// is f32[3,2]; the verifier must flag the root-shape mismatch.
TEST_F(HloVerifierTest, AsyncOpComputationRootWrongType) {
  const char* const kHloText = R"(
HloModule Module
async_computation {
p = f32[2,3] parameter(0)
ROOT custom-call = f32[3,2] custom-call(p), custom_call_target="foo"
}
ENTRY AsyncStartAndAsyncDone {
p0 = f32[2,3] parameter(0)
async-start = ((f32[2,3]), f32[2,3], u32[]) async-start(p0), calls=async_computation
ROOT async-done = f32[3,2] async-done(async-start), calls=async_computation
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto unverified_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(unverified_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("async-start expects the async shape at index {1} to "
                        "match the async computation root shape"));
}
// The async-start result shape here is a one-element tuple; the verifier
// requires the async shape to be a tuple of at least two elements.
TEST_F(HloVerifierTest, AsyncOpTupleWrongType) {
  const char* const kHloText = R"(
HloModule Module
async_computation {
p = f32[2,3] parameter(0)
ROOT custom-call = f32[3,2] custom-call(p), custom_call_target="foo"
}
ENTRY AsyncStartAndAsyncDone {
p0 = f32[2,3] parameter(0)
async-start = ((f32[2,3])) async-start(p0), calls=async_computation
ROOT async-done = f32[3,2] async-done(async-start), calls=async_computation
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto unverified_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(unverified_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("async-start expects the async shape to be a tuple of "
                        "at least two elements"));
}
// The async-start operand p0 is f32[3,2] but the async shape at index {0}
// declares f32[2,3]; the verifier must flag the operand-shape mismatch.
TEST_F(HloVerifierTest, AsyncStartOperandWrongType) {
  const char* const kHloText = R"(
HloModule Module
async_computation {
p = f32[2,3] parameter(0)
ROOT custom-call = f32[3,2] custom-call(p), custom_call_target="foo"
}
ENTRY AsyncStartAndAsyncDone {
p0 = f32[3,2] parameter(0)
async-start = ((f32[2,3]), f32[3,2], u32[]) async-start(p0), calls=async_computation
ROOT async-done = f32[3,2] async-done(async-start), calls=async_computation
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto unverified_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(unverified_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("async-start expects the shape of operand 0 to match "
                        "the async shape at index {0}"));
}
// The async-done output is f32[2,3] while the async shape at index {1} is
// f32[3,2]; the verifier must flag the output-shape mismatch.
TEST_F(HloVerifierTest, AsyncDoneOutputWrongType) {
  const char* const kHloText = R"(
HloModule Module
async_computation {
p = f32[2,3] parameter(0)
ROOT custom-call = f32[3,2] custom-call(p), custom_call_target="foo"
}
ENTRY AsyncStartAndAsyncDone {
p0 = f32[2,3] parameter(0)
async-start = ((f32[2,3]), f32[3,2], u32[]) async-start(p0), calls=async_computation
ROOT async-done = f32[2,3] async-done(async-start), calls=async_computation
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto unverified_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(unverified_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("async-done expects the shape of output to match the "
                        "async shape at index {1}"));
}
// The async-update output shape differs from its async-start operand's shape
// (first tuple element f32[3,2] vs f32[2,3]); the verifier must reject it.
TEST_F(HloVerifierTest, AsyncUpdateWrongType) {
  const char* const kHloText = R"(
HloModule Module
async_computation {
p = f32[2,3] parameter(0)
ROOT custom-call = f32[3,2] custom-call(p), custom_call_target="foo"
}
ENTRY AsyncStartAndAsyncDone {
p0 = f32[2,3] parameter(0)
async-start = ((f32[2,3]), f32[3,2], u32[]) async-start(p0), calls=async_computation
async-update = ((f32[3,2]), f32[3,2], u32[]) async-update(async-start), calls=async_computation
ROOT async-done = f32[3,2] async-done(async-update), calls=async_computation
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto unverified_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(unverified_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr(
          "async-update expects the shape of operand and output to match"));
}
// The async-done carries async_group_id=1 while the rest of its chain uses
// async_group_id=0; the verifier must report the "(1 vs 0)" mismatch.
TEST_F(HloVerifierTestLayoutSensitive, AsyncDoneWrongGroupId) {
  const char* const kHloText = R"(
HloModule Module
ENTRY AsyncStartAndAsyncUpdateAndAsyncDone {
p0 = f32[2,3]{1,0:S(1)} parameter(0)
async-start = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-start(p0), async_group_id=0, custom_call_target="foo"
async-update.1 = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-update(async-start), async_group_id=0, custom_call_target="foo"
async-update.2 = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-update(async-update.1), async_group_id=0, custom_call_target="foo"
ROOT async-done = f32[2,3]{1,0:S(2)} custom-call-done(async-update.2), async_group_id=1, custom_call_target="foo"
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto unverified_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(unverified_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("async-done expects its operand to have the same group "
                        "id (1 vs 0)."));
}
// async-update.1 omits async_group_id while its async-start operand uses
// async_group_id=0; the verifier must report the "(none vs 0)" mismatch.
TEST_F(HloVerifierTestLayoutSensitive, AsyncUpdateWrongGroupId) {
  const char* const kHloText = R"(
HloModule Module
ENTRY AsyncStartAndAsyncUpdateAndAsyncDone {
p0 = f32[2,3]{1,0:S(1)} parameter(0)
async-start = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-start(p0), async_group_id=0, custom_call_target="foo"
async-update.1 = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-update(async-start), custom_call_target="foo"
async-update.2 = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-update(async-update.1), async_group_id=0, custom_call_target="foo"
ROOT async-done = f32[2,3]{1,0:S(2)} custom-call-done(async-update.2), async_group_id=0, custom_call_target="foo"
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto unverified_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(unverified_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("async-update expects its operand to have the same "
                        "group id (none vs 0)."));
}
// An iota whose result shape is an empty tuple `()` must be rejected as a
// non-array result.
TEST_F(HloVerifierTest, IotaNonArrayResult) {
  const char* const kHloText = R"(
HloModule IotaTupleResult
ENTRY kernelEntry {
ROOT iota = () iota(), iota_dimension=24
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto unverified_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(unverified_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("does not support non-array result"));
}
// An iota with iota_dimension=-1 must fail verification with an error that
// mentions "negative".
TEST_F(HloVerifierTest, IotaNegativeDimension) {
  const char* const kHloText = R"(
HloModule IotaTupleResult
ENTRY kernelEntry {
ROOT iota = s32[128,1001]{1,0} iota(), iota_dimension=-1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto unverified_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(unverified_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(), HasSubstr("negative"));
}
// An iota producing a pred[128] result must fail verification; the error is
// expected to mention "got PRED".
TEST_F(HloVerifierTest, IotaPredResultNotAllowed) {
  const char* const kHloText = R"(
HloModule IotaPredResult
ENTRY kernelEntry {
ROOT iota = pred[128] iota(), iota_dimension=0
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto unverified_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(unverified_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(), HasSubstr("got PRED"));
}
// HLO text shared by the MapOperandComputationMismatch tests: a `map` whose
// operand is f64[] while the to_apply computation takes an f32[] parameter.
// (A namespace-scope const object already has internal linkage.)
const char* const kMapOperandComputationMismatchHlo = R"(
HloModule MapOperandComputationMismatch
Computation {
param0 = f32[] parameter(0)
constant = f32[] constant(1)
ROOT add = f32[] add(param0, constant)
}
ENTRY kernelEntry {
param = f64[] parameter(0)
ROOT map = f32[] map(param), dimensions={}, to_apply=Computation
})";
// With the default fixture, the shared map-mismatch module must fail
// verification with a to_apply parameter/operand shape mismatch.
TEST_F(HloVerifierTest, MapOperandComputationMismatch) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto unverified_module,
      ParseAndReturnUnverifiedModule(kMapOperandComputationMismatchHlo));
  const auto verify_status = verifier().Run(unverified_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr(
          "Shape mismatch between to_apply computation parameter and operand"));
}
// With the HloVerifierTestAllowMixedPrecision fixture, the same module parses
// through ParseAndReturnVerifiedModule and the verifier run succeeds.
TEST_F(HloVerifierTestAllowMixedPrecision, MapOperandComputationMismatch) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto verified_module,
      ParseAndReturnVerifiedModule(kMapOperandComputationMismatchHlo));
  const auto verify_status = verifier().Run(verified_module.get()).status();
  ASSERT_TRUE(verify_status.ok());
}
// HLO text shared by the ReduceOperandComputationMismatch tests: a `reduce`
// over f16 operands whose to_apply computation is written in f32.
// (A namespace-scope const object already has internal linkage.)
const char* const kReduceOperandComputationMismatchHlo = R"(
HloModule ReduceOperandComputationMismatch
computation {
x = f32[] parameter(0)
y = f32[] parameter(1)
ROOT add = f32[] add(x, y)
}
ENTRY kernelEntry {
arg0 = f16[64,64,224,224]{3,2,1,0} parameter(0)
constant = f16[] constant(0)
reduce = f16[64]{0} reduce(arg0, constant), dimensions={0,2,3}, to_apply=computation
})";
// With the default fixture, the shared reduce-mismatch module must fail
// verification; the error names the expected f32[64] result shape.
TEST_F(HloVerifierTest, ReduceOperandComputationMismatch) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto unverified_module,
      ParseAndReturnUnverifiedModule(kReduceOperandComputationMismatchHlo));
  const auto verify_status = verifier().Run(unverified_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected instruction to have shape equal to f32[64]"));
}
// With the HloVerifierTestAllowMixedPrecision fixture, the same module parses
// through ParseAndReturnVerifiedModule and the verifier run succeeds.
TEST_F(HloVerifierTestAllowMixedPrecision, ReduceOperandComputationMismatch) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto verified_module,
      ParseAndReturnVerifiedModule(kReduceOperandComputationMismatchHlo));
  const auto verify_status = verifier().Run(verified_module.get()).status();
  ASSERT_TRUE(verify_status.ok());
}
// Renders `replica_groups` as text: every group becomes "{a,b,...}" (members
// joined with ","), the groups are joined with ", ", and the whole list is
// wrapped in one more pair of braces.
std::string ReplicaGroupsStr(std::vector<std::vector<int64_t>> replica_groups) {
  std::vector<std::string> rendered_groups;
  rendered_groups.reserve(replica_groups.size());
  for (const std::vector<int64_t>& group : replica_groups) {
    const std::string members = absl::StrJoin(group, ",");
    rendered_groups.push_back(absl::StrFormat("{%s}", members));
  }
  const std::string all_groups = absl::StrJoin(rendered_groups, ", ");
  return absl::StrFormat("{%s}", all_groups);
}
// Returns the total number of replica ids listed across all groups in
// `replica_groups`, i.e. the sum of the group sizes (0 for an empty list).
int64_t ReplicaCount(const std::vector<std::vector<int64_t>>& replica_groups) {
  int64_t replica_count = 0;
  // Bind each group by const reference: iterating by value would copy every
  // inner vector just to read its size.
  for (const auto& group : replica_groups) {
    replica_count += static_cast<int64_t>(group.size());
  }
  return replica_count;
}
// Builds an unverified module from `template_str` after substituting its
// placeholders:
//   REPLICA_GROUPS   -> the textual form of `replica_groups`
//   OTHER_ATTRIBUTES -> "," + `other_attributes`, or "" when that is empty
// The module config uses `replica_count` (defaulting to the number of ids
// listed in `replica_groups`) and `num_partitions` (defaulting to 1).
StatusOr<std::unique_ptr<HloModule>> MakeCollectiveCommOpComputation(
    std::vector<std::vector<int64_t>> replica_groups,
    std::optional<int64_t> replica_count, std::optional<int64_t> num_partitions,
    absl::string_view other_attributes, absl::string_view template_str) {
  HloModuleConfig module_config;
  module_config.set_replica_count(
      replica_count.value_or(ReplicaCount(replica_groups)));
  module_config.set_num_partitions(num_partitions.value_or(1));

  const std::string groups_text = ReplicaGroupsStr(replica_groups);
  const std::string attributes_text =
      other_attributes.empty() ? "" : absl::StrCat(",", other_attributes);
  const std::string hlo_text =
      absl::StrReplaceAll(template_str, {{"REPLICA_GROUPS", groups_text},
                                         {"OTHER_ATTRIBUTES", attributes_text}});
  return ParseAndReturnUnverifiedModule(hlo_text, module_config);
}
// Builds an unverified module containing a single f32[128] all-reduce with the
// given `replica_groups`; the remaining arguments are forwarded unchanged to
// MakeCollectiveCommOpComputation, which fills in the template placeholders.
StatusOr<std::unique_ptr<HloModule>> MakeAllReduceComputation(
    std::vector<std::vector<int64_t>> replica_groups,
    std::optional<int64_t> replica_count = std::nullopt,
    std::optional<int64_t> num_partitions = std::nullopt,
    absl::string_view other_attributes = "") {
  const char* const kAllReduceTemplate = R"(
HloModule test
add {
x = f32[] parameter(0)
y = f32[] parameter(1)
ROOT add = f32[] add(x, y)
}
ENTRY entry {
p = f32[128]{0} parameter(0)
crs = f32[128]{0} all-reduce(p), to_apply=add, replica_groups=REPLICA_GROUPS
OTHER_ATTRIBUTES
})";
  return MakeCollectiveCommOpComputation(replica_groups, replica_count,
                                         num_partitions, other_attributes,
                                         kAllReduceTemplate);
}
// An all-reduce built with an empty replica_groups list passes verification.
TEST_F(HloVerifierTest, AllReduce_NoReplicaGroupsOK) {
  TF_ASSERT_OK_AND_ASSIGN(auto all_reduce_module, MakeAllReduceComputation({}));
  const auto verify_status = verifier().Run(all_reduce_module.get()).status();
  TF_ASSERT_OK(verify_status);
}
// An all-reduce whose replica groups have different sizes ({0}, {1,3}, {2})
// passes verification.
TEST_F(HloVerifierTest, AllReduce_DifferentGroupSizesOk) {
  TF_ASSERT_OK_AND_ASSIGN(auto all_reduce_module,
                          MakeAllReduceComputation({{0}, {1, 3}, {2}}));
  const auto verify_status = verifier().Run(all_reduce_module.get()).status();
  TF_ASSERT_OK(verify_status);
}
// A replica_groups list containing an empty group ({{0}, {}}) is reported with
// an "empty replica group" error.
TEST_F(HloVerifierTest, AllReduce_EmptyReplicaGroup) {
  TF_ASSERT_OK_AND_ASSIGN(auto all_reduce_module,
                          MakeAllReduceComputation({{0}, {}}));
  const auto verify_status = verifier().Run(all_reduce_module.get()).status();
  EXPECT_THAT(verify_status.error_message(), HasSubstr("empty replica group"));
}
// Replica 0 appears in two groups ({0,1} and {4,0}); the verifier must report
// the repetition.
TEST_F(HloVerifierTest, AllReduce_RepeatedReplicaId) {
  TF_ASSERT_OK_AND_ASSIGN(auto all_reduce_module,
                          MakeAllReduceComputation({{0, 1}, {2, 3}, {4, 0}}));
  const auto verify_status = verifier().Run(all_reduce_module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Replica 0 is repeated"));
}
// The groups list ids 0-3 and 5-6 but never 4; the verifier must report that
// replica 4 is not named.
TEST_F(HloVerifierTest, AllReduce_MissingReplicaId) {
  TF_ASSERT_OK_AND_ASSIGN(auto all_reduce_module,
                          MakeAllReduceComputation({{0, 1}, {2, 3}, {5, 6}}));
  const auto verify_status = verifier().Run(all_reduce_module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Replica 4 is not named"));
}
// The module config declares 8 replicas but the replica groups name only 2;
// the verifier must report the count mismatch for kCrossReplica mode.
TEST_F(HloVerifierTest, AllReduce_NotEnougReplicasInGroupConfig) {
  TF_ASSERT_OK_AND_ASSIGN(auto all_reduce_module,
                          MakeAllReduceComputation({{0, 1}}, 8));
  const auto verify_status = verifier().Run(all_reduce_module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("In kCrossReplica mode, replica groups should contain "
                        "8 replicas, but found 2"));
}
// The module config declares 2 replicas but the replica groups name 4; the
// verifier must report the count mismatch for kCrossReplica mode.
TEST_F(HloVerifierTest, AllReduce_TooManyReplicasInGroupConfig) {
  TF_ASSERT_OK_AND_ASSIGN(auto all_reduce_module,
                          MakeAllReduceComputation({{0, 1}, {2, 3}}, 2));
  const auto verify_status = verifier().Run(all_reduce_module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("In kCrossReplica mode, replica groups should contain "
                        "2 replicas, but found 4"));
}
// With channel_id=1, replica_count=2 and num_partitions=1, groups naming 4 ids
// must be rejected with the kCrossReplicaAndPartition count message.
TEST_F(HloVerifierTest, AllReduce_CrossReplicaAndPartition_Invalid) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto all_reduce_module,
      MakeAllReduceComputation({{0, 1}, {2, 3}}, 2, 1, "channel_id=1"));
  const auto verify_status = verifier().Run(all_reduce_module.get()).status();
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr(
          "In kCrossReplicaAndPartition mode, replica groups should contain "
          "2 replicas, but found 4"));
}
// With channel_id=1, replica_count=4 and num_partitions=1, groups naming 4 ids
// pass verification.
TEST_F(HloVerifierTest, AllReduce_CrossReplicaAndPartition_Valid) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto all_reduce_module,
      MakeAllReduceComputation({{0, 1}, {2, 3}}, 4, 1, "channel_id=1"));
  const auto verify_status = verifier().Run(all_reduce_module.get()).status();
  TF_ASSERT_OK(verify_status);
}
// With use_global_device_ids=true, replica_count=1 and num_partitions=2,
// groups naming 4 ids must be rejected with the kFlattenedID count message.
TEST_F(HloVerifierTest, AllReduce_FlattenedID_Invalid) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto all_reduce_module,
      MakeAllReduceComputation({{0, 1}, {2, 3}}, 1, 2,
                               "channel_id=1, use_global_device_ids=true"));
  const auto verify_status = verifier().Run(all_reduce_module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("In kFlattenedID mode, replica groups should contain "
                        "2 flattened IDs, but found 4"));
}
// With use_global_device_ids=true, replica_count=2 and num_partitions=2,
// groups naming 4 ids pass verification.
TEST_F(HloVerifierTest, AllReduce_FlattenedID_Valid) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto all_reduce_module,
      MakeAllReduceComputation({{0, 1}, {2, 3}}, 2, 2,
                               "channel_id=1, use_global_device_ids=true"));
  const auto verify_status = verifier().Run(all_reduce_module.get()).status();
  TF_ASSERT_OK(verify_status);
}
// A matched all-reduce-start / all-reduce-done pair with consistent f32[2,3]
// shapes passes verification.
TEST_F(HloVerifierTest, AllReduceStartAndDone) {
  const char* const kHloText = R"(
HloModule test
add {
x = f32[] parameter(0)
y = f32[] parameter(1)
ROOT add = f32[] add(x, y)
}
ENTRY entry {
p0 = f32[2,3] parameter(0)
start = f32[2,3] all-reduce-start(p0), to_apply=add
ROOT done = f32[2,3] all-reduce-done(start)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> unverified_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(unverified_module.get()).status();
  ASSERT_TRUE(verify_status.ok());
}
// An all-reduce-start declared with a tuple result shape for a single f32[2,3]
// operand must fail verification; the error names the expected f32[2,3] shape.
TEST_F(HloVerifierTest, AllReduceStartAndDoneWrongType) {
  const char* const kModuleStr = R"(
HloModule test
add {
x = f32[] parameter(0)
y = f32[] parameter(1)
ROOT add = f32[] add(x, y)
}
ENTRY entry {
p0 = f32[2,3] parameter(0)
start = (f32[2,3], f32[2,3]) all-reduce-start(p0), to_apply=add
ROOT done = f32[2,3] all-reduce-done(start)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kModuleStr));
  auto status = verifier().Run(module.get()).status();
  // Fail fast with a clear message if the verifier unexpectedly accepts the
  // module, matching the neighbouring all-reduce start/done tests.
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("Expected instruction to have shape equal to "
                        "f32[2,3]"));
}
// One all-reduce-start consumed by two all-reduce-done instructions must fail
// verification: exactly one consumer is required, and two are found.
TEST_F(HloVerifierTest, AllReduceStartAndMultipleDone) {
  const char* const kHloText = R"(
HloModule test
add {
x = f32[] parameter(0)
y = f32[] parameter(1)
ROOT add = f32[] add(x, y)
}
ENTRY entry {
p0 = f32[2,3] parameter(0)
start = (f32[2,3], f32[2,3]) all-reduce-start(p0), to_apply=add
done1 = f32[2,3] all-reduce-done(start)
ROOT done2 = f32[2,3] all-reduce-done(start)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> unverified_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(unverified_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr("all-reduce-start instruction requires one consumer, found 2"));
}
// An all-reduce-done whose operand is a plain tuple instead of an
// all-reduce-start must fail verification.
TEST_F(HloVerifierTest, AllReduceDoneWithoutStart) {
  const char* const kHloText = R"(
HloModule test
ENTRY entry {
p0 = f32[2,3] parameter(0)
p1 = u32[] parameter(1)
tuple = (f32[2,3], f32[2,3]) tuple(p0, p0, p1, p1)
ROOT done = f32[2,3] all-reduce-done(tuple)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> unverified_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(unverified_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("The operand of a all-reduce-done instruction "
                        "needs to be all-reduce-start, found tuple"));
}
// Builds an unverified module containing a two-operand f32[128] all-to-all
// with the given `replica_groups`; the remaining arguments are forwarded
// unchanged to MakeCollectiveCommOpComputation, which fills in the template
// placeholders.
StatusOr<std::unique_ptr<HloModule>> MakeAllToAllComputation(
    std::vector<std::vector<int64_t>> replica_groups,
    std::optional<int64_t> replica_count = std::nullopt,
    std::optional<int64_t> num_partitions = std::nullopt,
    absl::string_view other_attributes = "") {
  const char* const kAllToAllTemplate = R"(
HloModule test
ENTRY entry {
p0 = f32[128]{0} parameter(0)
p1 = f32[128]{0} parameter(1)
a2a = (f32[128], f32[128]) all-to-all(p0, p1), replica_groups=REPLICA_GROUPS
OTHER_ATTRIBUTES
})";
  return MakeCollectiveCommOpComputation(replica_groups, replica_count,
                                         num_partitions, other_attributes,
                                         kAllToAllTemplate);
}
// An all-to-all built with an empty replica_groups list passes verification.
TEST_F(HloVerifierTest, AllToAll_NoReplicaGroupsOK) {
  TF_ASSERT_OK_AND_ASSIGN(auto all_to_all_module, MakeAllToAllComputation({}));
  const auto verify_status = verifier().Run(all_to_all_module.get()).status();
  TF_ASSERT_OK(verify_status);
}
// A replica_groups list containing an empty group ({{0,1}, {}}) must be
// rejected.
TEST_F(HloVerifierTest, AllToAll_EmptyReplicaGroup) {
  TF_ASSERT_OK_AND_ASSIGN(auto all_to_all_module,
                          MakeAllToAllComputation({{0, 1}, {}}));
  const auto verify_status = verifier().Run(all_to_all_module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("cannot have an empty replica group"));
}
// Replica 0 appears in two groups ({0,1} and {4,0}); the verifier must report
// the repetition.
TEST_F(HloVerifierTest, AllToAll_RepeatedReplicaId) {
  TF_ASSERT_OK_AND_ASSIGN(auto all_to_all_module,
                          MakeAllToAllComputation({{0, 1}, {2, 3}, {4, 0}}));
  const auto verify_status = verifier().Run(all_to_all_module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Replica 0 is repeated"));
}
// The groups list ids 0-3 and 5-6 but never 4; the verifier must report that
// replica 4 is not named.
TEST_F(HloVerifierTest, AllToAll_MissingReplicaId) {
  TF_ASSERT_OK_AND_ASSIGN(auto all_to_all_module,
                          MakeAllToAllComputation({{0, 1}, {2, 3}, {5, 6}}));
  const auto verify_status = verifier().Run(all_to_all_module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Replica 4 is not named"));
}
// Groups of sizes 2, 1 and 2 must be rejected: all-to-all replica groups are
// expected to be of uniform size.
TEST_F(HloVerifierTest, AllToAll_UniformSizeOfReplicasInGroup) {
  TF_ASSERT_OK_AND_ASSIGN(auto all_to_all_module,
                          MakeAllToAllComputation({{0, 1}, {2}, {3, 4}}));
  const auto verify_status = verifier().Run(all_to_all_module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Replica groups expected to be of uniform size"));
}
// With channel_id=1, replica_count=1 and num_partitions=2, groups naming 4 ids
// must be rejected with the kCrossPartition count message.
TEST_F(HloVerifierTest, AllToAll_CrossPartition_Invalid) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto all_to_all_module,
      MakeAllToAllComputation({{0, 1}, {2, 3}}, 1, 2, "channel_id=1"));
  const auto verify_status = verifier().Run(all_to_all_module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("In kCrossPartition mode, replica groups should "
                        "contain 2 partitions, but found 4"));
}
// With channel_id=1, replica_count=1 and num_partitions=4, groups naming 4 ids
// pass verification.
TEST_F(HloVerifierTest, AllToAll_CrossPartition_Valid) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto all_to_all_module,
      MakeAllToAllComputation({{0, 1}, {2, 3}}, 1, 4, "channel_id=1"));
  const auto verify_status = verifier().Run(all_to_all_module.get()).status();
  TF_ASSERT_OK(verify_status);
}
// The two all-to-all operands are written with different layouts ({0,1} vs
// {1,0}); with replica_count=2 the verifier is expected to report operands
// with different shapes.
TEST_F(HloVerifierTest, AllToAll_LayoutConstrained) {
  const char* const kHloText = R"(
HloModule test
ENTRY entry {
p0 = f32[128,4]{0,1} parameter(0)
p1 = f32[128,4]{1,0} parameter(1)
ROOT a2a = (f32[128,4]{0,1}, f32[128,4]{1,0}) all-to-all(p0, p1),
replica_groups={{0,1}}
}
)";
  HloModuleConfig module_config;
  module_config.set_replica_count(2);
  TF_ASSERT_OK_AND_ASSIGN(
      auto unverified_module,
      ParseAndReturnUnverifiedModule(kHloText, module_config));
  const auto verify_status = verifier().Run(unverified_module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("HLO all-to-all has operands with different shapes"));
}
// Source 0 is listed in two source_target_pairs ({0,1} and {0,2}); with
// replica_count=3 the verifier must report the duplicate source.
TEST_F(HloVerifierTest, CollectivePermuteSameSourceTwice) {
  const char* const kHloText = R"(
HloModule test
ENTRY entry {
p0 = f32[128] parameter(0)
ROOT permute = f32[128] collective-permute(p0),
source_target_pairs={{0,1}, {0,2}, {1,0}}
}
)";
  HloModuleConfig module_config;
  module_config.set_replica_count(3);
  TF_ASSERT_OK_AND_ASSIGN(
      auto unverified_module,
      ParseAndReturnUnverifiedModule(kHloText, module_config));
  const auto verify_status = verifier().Run(unverified_module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Source 0 appears more than once"));
}
// Target 2 is listed in two source_target_pairs ({0,2} and {1,2}); the
// verifier must report the duplicate target.
TEST_F(HloVerifierTest, CollectivePermuteSameTargetTwice) {
  const char* const kHloText = R"(
HloModule test
ENTRY entry {
p0 = f32[128] parameter(0)
ROOT permute = f32[128] collective-permute(p0),
source_target_pairs={{0,2}, {1,2}, {2,0}}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto unverified_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(unverified_module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Target 2 appears more than once"));
}
// In this four-operand collective-permute (with slice_sizes), source 0 is
// listed three times in source_target_pairs; the verifier must report that it
// appears more than 2 times.
TEST_F(HloVerifierTest, CollectivePermuteSameSourceTooManyTimes) {
  const char* const kHloText = R"(
HloModule test
ENTRY entry {
replica_id = u32[] replica-id()
broadcast.0 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] replica_id), dimensions={}
constant.1 = u32[] constant(1000)
broadcast.1 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
constant.2 = s32[] constant(0)
constant.3 = s32[] constant(1)
tuple.2 = (s32[],s32[],s32[]) tuple(constant.2, constant.2, constant.2)
tuple.3 = (s32[],s32[],s32[]) tuple(constant.3, constant.2, constant.2)
tuple.4 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.3)
ROOT collective-permute = u32[2,8,128]{2,1,0:T(2,128)} collective-permute(u32[2,8,128] broadcast.0, u32[2,8,128] broadcast.1, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4), source_target_pairs={{0,1},{1,2},{2,3},{3,0},{0,3},{0,2},{2,1},{1,0}}, slice_sizes={{1,8,128},{1,8,128}}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto unverified_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(unverified_module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Source 0 appears more than 2 times in instruction's "
                        "source-target pairs:"));
}
// In this four-operand collective-permute (with slice_sizes), target 3 is
// listed three times in source_target_pairs; the verifier must report that it
// appears more than 2 times.
TEST_F(HloVerifierTest, CollectivePermuteSameTargetTooManyTimes) {
  const char* const kHloText = R"(
HloModule test
ENTRY entry {
replica_id = u32[] replica-id()
broadcast.0 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] replica_id), dimensions={}
constant.1 = u32[] constant(1000)
broadcast.1 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
constant.2 = s32[] constant(0)
constant.3 = s32[] constant(1)
tuple.2 = (s32[],s32[],s32[]) tuple(constant.2, constant.2, constant.2)
tuple.3 = (s32[],s32[],s32[]) tuple(constant.3, constant.2, constant.2)
tuple.4 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.3)
ROOT collective-permute = u32[2,8,128]{2,1,0:T(2,128)} collective-permute(u32[2,8,128] broadcast.0, u32[2,8,128] broadcast.1, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4), source_target_pairs={{0,1},{1,2},{2,3},{3,0},{0,3},{3,2},{2,3},{1,0}}, slice_sizes={{1,8,128},{1,8,128}}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto unverified_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(unverified_module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Target 3 appears more than 2 times in instruction's "
                        "source-target pairs:"));
}
// The collective-permute's first operand is a single array (broadcast.0) while
// its second operand is a two-element tuple (tuple.output); the verifier must
// report unmatching input and output buffers.
TEST_F(HloVerifierTest, CollectivePermuteUnmatchingSourceTarget) {
  const char* const kHloText = R"(
HloModule test
ENTRY entry {
replica_id = u32[] replica-id()
broadcast.0 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] replica_id), dimensions={}
constant.1 = u32[] constant(1000)
broadcast.1 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
broadcast.2 = u32[4,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
constant.2 = s32[] constant(0)
constant.3 = s32[] constant(1)
tuple.output = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple(u32[2,8,128]{2,1,0:T(2,128)} broadcast.1, u32[4,8,128]{2,1,0:T(2,128)} broadcast.2)
tuple.2 = (s32[],s32[],s32[]) tuple(constant.2, constant.2, constant.2)
tuple.3 = (s32[],s32[],s32[]) tuple(constant.3, constant.2, constant.2)
tuple.4 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.3)
constant.4 = s32[] constant(2)
tuple.5 = (s32[],s32[],s32[]) tuple(constant.4, constant.2, constant.2)
tuple.6 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.5)
tuple.9 = (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple(((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.6)
ROOT collective-permute.53 = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) collective-permute(u32[2,8,128]{2,1,0:T(2,128)} broadcast.0, (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple.output, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple.9), source_target_pairs={{0,1},{1,2},{2,3},{3,0},{0,3},{3,2},{2,1},{1,0}}, slice_sizes={{1,8,128},{1,8,128},{2,8,128},{2,8,128}}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto unverified_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(unverified_module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Unmatching input buffers and output buffers"));
}
// The collective-permute's first operand is a two-element tuple (tuple.input)
// but its third operand is a single (s32[],s32[],s32[]) tuple (tuple.3); the
// verifier must report unmatching input buffers and input offset.
TEST_F(HloVerifierTest, CollectivePermuteUnmatchingInputAndInputOffset) {
  const char* const kHloText = R"(
HloModule test
ENTRY entry {
replica_id = u32[] replica-id()
broadcast.0 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] replica_id), dimensions={}
constant.1 = u32[] constant(1000)
broadcast.1 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
broadcast.2 = u32[4,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
constant.2 = s32[] constant(0)
constant.3 = s32[] constant(1)
tuple.input = (u32[2,8,128]{2,1,0:T(2,128)}, u32[2,8,128]{2,1,0:T(2,128)}) tuple(u32[2,8,128]{2,1,0:T(2,128)} broadcast.0, u32[2,8,128]{2,1,0:T(2,128)} broadcast.0)
tuple.output = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple(u32[2,8,128]{2,1,0:T(2,128)} broadcast.1, u32[4,8,128]{2,1,0:T(2,128)} broadcast.2)
tuple.2 = (s32[],s32[],s32[]) tuple(constant.2, constant.2, constant.2)
tuple.3 = (s32[],s32[],s32[]) tuple(constant.3, constant.2, constant.2)
tuple.4 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.3)
constant.4 = s32[] constant(2)
tuple.5 = (s32[],s32[],s32[]) tuple(constant.4, constant.2, constant.2)
tuple.6 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.5)
tuple.9 = (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple(((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.6)
ROOT collective-permute.53 = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) collective-permute((u32[2,8,128]{2,1,0:T(2,128)}, u32[2,8,128]{2,1,0:T(2,128)}) tuple.input, (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple.output, (s32[],s32[],s32[]) tuple.3, (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple.9), source_target_pairs={{0,1},{1,2},{2,3},{3,0},{0,3},{3,2},{2,1},{1,0}}, slice_sizes={{1,8,128},{1,8,128},{2,8,128},{2,8,128}}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto unverified_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(unverified_module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Unmatching input buffers and input offset."));
}
// The collective-permute's second operand is a two-element tuple
// (tuple.output) but its fourth operand is a single (s32[],s32[],s32[]) tuple
// (tuple.2); the verifier must report unmatching output buffers and output
// offset.
TEST_F(HloVerifierTest, CollectivePermuteUnmatchingOutputAndOutputOffset) {
  const char* const kHloText = R"(
HloModule test
ENTRY entry {
replica_id = u32[] replica-id()
broadcast.0 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] replica_id), dimensions={}
constant.1 = u32[] constant(1000)
broadcast.1 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
broadcast.2 = u32[4,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
constant.2 = s32[] constant(0)
constant.3 = s32[] constant(1)
tuple.input = (u32[2,8,128]{2,1,0:T(2,128)}, u32[2,8,128]{2,1,0:T(2,128)}) tuple(u32[2,8,128]{2,1,0:T(2,128)} broadcast.0, u32[2,8,128]{2,1,0:T(2,128)} broadcast.0)
tuple.output = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple(u32[2,8,128]{2,1,0:T(2,128)} broadcast.1, u32[4,8,128]{2,1,0:T(2,128)} broadcast.2)
tuple.2 = (s32[],s32[],s32[]) tuple(constant.2, constant.2, constant.2)
tuple.3 = (s32[],s32[],s32[]) tuple(constant.3, constant.2, constant.2)
tuple.4 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.3)
constant.4 = s32[] constant(2)
tuple.5 = (s32[],s32[],s32[]) tuple(constant.4, constant.2, constant.2)
tuple.7 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.2)
tuple.8 = (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple(((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.7)
ROOT collective-permute.53 = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) collective-permute((u32[2,8,128]{2,1,0:T(2,128)}, u32[2,8,128]{2,1,0:T(2,128)}) tuple.input, (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple.output, (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple.8, (s32[],s32[],s32[]) tuple.2), source_target_pairs={{0,1},{1,2},{2,3},{3,0},{0,3},{3,2},{2,1},{1,0}}, slice_sizes={{1,8,128},{1,8,128},{2,8,128},{2,8,128}}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto unverified_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(unverified_module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Unmatching output buffers and output offset."));
}
// With replica_count=3 and no channel_id, a source id of 5 must be reported as
// out of range ("must be < 3").
TEST_F(HloVerifierTest, CollectivePermuteCrossReplicaSourceOOR) {
  const char* const kHloText = R"(
HloModule test
ENTRY entry {
p0 = f32[128] parameter(0)
ROOT permute = f32[128] collective-permute(p0),
source_target_pairs={{5,2}, {1,2}, {2,0}}
}
)";
  HloModuleConfig module_config;
  module_config.set_replica_count(3);
  TF_ASSERT_OK_AND_ASSIGN(
      auto unverified_module,
      ParseAndReturnUnverifiedModule(kHloText, module_config));
  const std::string verifier_error =
      verifier().Run(unverified_module.get()).status().error_message();
  EXPECT_THAT(verifier_error, HasSubstr("Source 5"));
  EXPECT_THAT(verifier_error, HasSubstr("must be < 3"));
}
// With replica_count=3 and no channel_id, a target id of 7 must be reported as
// out of range ("must be < 3").
TEST_F(HloVerifierTest, CollectivePermuteCrossReplicaTargetOOR) {
  const char* const kHloText = R"(
HloModule test
ENTRY entry {
p0 = f32[128] parameter(0)
ROOT permute = f32[128] collective-permute(p0),
source_target_pairs={{0,1}, {1,2}, {2,7}}
}
)";
  HloModuleConfig module_config;
  module_config.set_replica_count(3);
  TF_ASSERT_OK_AND_ASSIGN(
      auto unverified_module,
      ParseAndReturnUnverifiedModule(kHloText, module_config));
  const std::string verifier_error =
      verifier().Run(unverified_module.get()).status().error_message();
  EXPECT_THAT(verifier_error, HasSubstr("Target 7"));
  EXPECT_THAT(verifier_error, HasSubstr("must be < 3"));
}
// With num_partitions=3 and channel_id=1, a source id of 5 must be reported as
// out of range ("must be < 3").
TEST_F(HloVerifierTest, CollectivePermuteCrossPartitionSourceOOR) {
  const char* const kHloText = R"(
HloModule test
ENTRY entry {
p0 = f32[128] parameter(0)
ROOT permute = f32[128] collective-permute(p0),
source_target_pairs={{5,2}, {1,2}, {2,0}}, channel_id=1
}
)";
  HloModuleConfig module_config;
  module_config.set_num_partitions(3);
  TF_ASSERT_OK_AND_ASSIGN(
      auto unverified_module,
      ParseAndReturnUnverifiedModule(kHloText, module_config));
  const std::string verifier_error =
      verifier().Run(unverified_module.get()).status().error_message();
  EXPECT_THAT(verifier_error, HasSubstr("Source 5"));
  EXPECT_THAT(verifier_error, HasSubstr("must be < 3"));
}
// With channel_id=1 set on the collective-permute and the config declaring 3
// partitions, target id 7 in the pair {1,7} should be reported as out of
// range.
TEST_F(HloVerifierTest, CollectivePermuteCrossPartitionTargetOOR) {
  const char* const hlo_text = R"(
HloModule test
ENTRY entry {
p0 = f32[128] parameter(0)
ROOT permute = f32[128] collective-permute(p0),
source_target_pairs={{0,2}, {1,7}, {2,0}}, channel_id=1
}
)";
  HloModuleConfig partitioned;
  partitioned.set_num_partitions(3);
  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module, ParseAndReturnUnverifiedModule(hlo_text, partitioned));
  const auto run_result = verifier().Run(hlo_module.get());
  const std::string reported = run_result.status().error_message();
  EXPECT_THAT(reported, HasSubstr("Target 7"));
  EXPECT_THAT(reported, HasSubstr("must be < 3"));
}
// The fusion instruction is declared as f32[10] while the root of its fused
// computation is f32[10,10]; the verifier should complain about the fused
// computation shape.
TEST_F(HloVerifierTest, FusionShapeVerifier) {
  static constexpr char kHloText[] = R"(
HloModule test
fused_computation {
ROOT p0 = f32[10,10] parameter(0)
}
ENTRY entry {
p0 = f32[10,10] parameter(0)
ROOT out = f32[10] fusion(p0), kind=kInput, calls=fused_computation
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(parsed.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Fused computation shape"));
}
// The fused computation is tagged thread_name="parallel_thread" whereas the
// entry computation that calls it carries no thread name; the verifier is
// expected to report the mismatch.
TEST_F(HloVerifierTest, FusionThreadVerifier) {
  const char* const hlo_text = R"(
HloModule test
fused_computation {
ROOT p0 = f32[8,12] parameter(0)
}, thread_name="parallel_thread"
ENTRY entry {
p0 = f32[8,12] parameter(0)
ROOT out = f32[8,12] fusion(p0), kind=kInput, calls=fused_computation
}
)";
  const char* const expected_error =
      "expects parent computation thread name same as called "
      "computation's thread name";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(hlo_text));
  const auto run_status = verifier().Run(hlo_module.get()).status();
  EXPECT_THAT(run_status.error_message(), HasSubstr(expected_error));
}
// The reduction computation `add` is tagged thread_name="parallel_thread" and
// is applied by an all-reduce that lives inside a fused computation with no
// thread name. The verifier should flag the nested thread-name mismatch.
TEST_F(HloVerifierTest, FusionNestedComputationThreadVerifier) {
  static constexpr char kHloText[] = R"(
HloModule test
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}, thread_name="parallel_thread"
fused_computation {
p0 = f32[8,12] parameter(0)
p1 = f32[8,12] parameter(1)
crs0 = f32[8,12] all-reduce(p1), replica_groups={}, to_apply=add
ROOT result = add(p0, crs0)
}
ENTRY entry {
p0 = f32[8,12] parameter(0)
p1 = f32[8,12] parameter(1)
ROOT out = f32[8,12] fusion(p0, p1), kind=kInput, calls=fused_computation
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(parsed.get()).status();
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr("Nested computations expects same computation's thread name"));
}
// Two all-reduces over the same input: crs1 sets constrain_layout=true while
// crs0 does not. The verifier is expected to reject this combination.
TEST_F(HloVerifierTest, AllReduceVerifier) {
  const char* const hlo_text = R"(
HloModule test
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY entry {
input = f32[8,12]{0,1} parameter(0)
crs0 = f32[8,12]{0,1} all-reduce(input), replica_groups={}, to_apply=add
crs1 = f32[8,12]{0,1} all-reduce(input), replica_groups={}, to_apply=add,
constrain_layout=true
ROOT result = (f32[8,12]{0,1}, f32[8,12]{0,1}) tuple(crs0, crs1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(hlo_text));
  const auto run_result = verifier().Run(hlo_module.get());
  EXPECT_THAT(
      run_result.status().error_message(),
      HasSubstr("mix of layout constrained and unconstrained AllReduce"));
}
// A send/send-done pair and an all-reduce all carry channel_id=1. The verifier
// should report that the channel is shared by different kinds of channel
// instructions.
TEST_F(HloVerifierTest, ChannelVerifier) {
  static constexpr char kHloText[] = R"(
HloModule test
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY entry {
%input = f32[8,12] parameter(0)
%token0 = token[] after-all()
%send = (f32[8,12], u32[], token[]) send(%input, %token0), channel_id=1
%send-done = token[] send-done(%send), channel_id=1
%crs = f32[8,12] all-reduce(%input), replica_groups={}, to_apply=add,
channel_id=1
ROOT result = (f32[8,12]{0,1}, f32[8,12]{0,1}) tuple(%input, %crs)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(parsed.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("used for different types of channel instructions"));
}
// A collective-permute and an all-reduce both use channel_id=1. As in the
// send/all-reduce case, the verifier should say the channel id is used for
// different types of channel instructions.
TEST_F(HloVerifierTest, CollectiveChannelVerifier) {
  const char* const hlo_text = R"(
HloModule test
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY entry {
%input = f32[8,12] parameter(0)
%permute = f32[8,12] collective-permute(%input),
source_target_pairs={{0,1},{1,0}}, channel_id=1
%crs = f32[8,12] all-reduce(%input), replica_groups={}, to_apply=add,
channel_id=1
ROOT result = (f32[8,12]{0,1}, f32[8,12]{0,1}) tuple(%permute, %crs)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(hlo_text));
  const auto run_result = verifier().Run(hlo_module.get());
  EXPECT_THAT(run_result.status().error_message(),
              HasSubstr("used for different types of channel instructions"));
}
// Positive case: a collective-permute-start producing a
// (operand, result, u32[], u32[]) tuple, consumed by exactly one
// collective-permute-done, must pass verification. Runs under the
// HloVerifierTestLayoutSensitive fixture (defined outside this chunk).
TEST_F(HloVerifierTestLayoutSensitive, CollectivePermuteStartAndDone) {
  const char* const kModuleStr = R"(
HloModule Module
ENTRY CollectivePermuteStartAndDone {
p0 = f32[2,3]{1,0:S(1)} parameter(0)
collective-permute-start.1 = (f32[2,3]{1,0:S(1)}, f32[2,3]{1,0:S(1)}, u32[], u32[]) collective-permute-start(p0), source_target_pairs={{0,1},{1,0}}, channel_id=1
ROOT collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(kModuleStr));
  // TF_ASSERT_OK includes the verifier's error text in the failure output,
  // which ASSERT_TRUE(status.ok()) would discard.
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
// collective-permute-start is declared with a plain array shape instead of the
// 4-element tuple named in the expected error text; verification must fail
// with that shape-mismatch message.
TEST_F(HloVerifierTest, CollectivePermuteStartAndDoneWrongType) {
  static constexpr char kHloText[] = R"(
HloModule Module
ENTRY CollectivePermuteStartAndDoneWrongType {
p0 = f32[2,3]{1,0:S(1)} parameter(0)
collective-permute-start.1 = f32[2,3]{1,0:S(1)} collective-permute-start(p0), source_target_pairs={{0,1},{1,0}}, channel_id=1
ROOT collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verdict = verifier().Run(parsed.get()).status();
  ASSERT_FALSE(verdict.ok());
  const char* const expected_error =
      "Expected instruction to have shape equal to "
      "(f32[2,3], f32[2,3], u32[], u32[])";
  EXPECT_THAT(verdict.error_message(), HasSubstr(expected_error));
}
// Two collective-permute-done instructions consume the same
// collective-permute-start; the verifier must reject the module and report
// that it found 2 consumers.
TEST_F(HloVerifierTest, CollectivePermuteStartAndMultipleDone) {
  const char* const hlo_text = R"(
HloModule Module
ENTRY CollectivePermuteStartAndMultipleDone {
p0 = f32[2,3]{1,0:S(1)} parameter(0)
collective-permute-start.1 = (f32[2,3]{1,0:S(1)}, f32[2,3]{1,0:S(1)}, u32[], u32[]) collective-permute-start(p0), source_target_pairs={{0,1},{1,0}}, channel_id=1
collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1)
ROOT collective-permute-done.2 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(hlo_text));
  const auto verdict = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verdict.ok());
  const char* const expected_error =
      "collective-permute-start instruction requires one consumer, "
      "found 2";
  EXPECT_THAT(verdict.error_message(), HasSubstr(expected_error));
}
// collective-permute-done is fed a plain tuple instruction rather than a
// collective-permute-start; verification must fail and name the offending
// operand opcode ("tuple").
TEST_F(HloVerifierTest, CollectivePermuteDoneNoCollectivePermuteStart) {
  static constexpr char kHloText[] = R"(
HloModule Module
ENTRY CollectivePermuteDoneNoCollectivePermuteStart {
p0 = f32[2,3]{1,0:S(1)} parameter(0)
p1 = f32[2,3]{1,0:S(1)} parameter(1)
p2 = u32[] parameter(2)
p3 = u32[] parameter(3)
tuple.1 = (f32[2,3], f32[2,3], u32[], u32[]) tuple(p0, p1, p2, p3)
ROOT collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(tuple.1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verdict = verifier().Run(parsed.get()).status();
  ASSERT_FALSE(verdict.ok());
  const char* const expected_error =
      "The operand of a collective-permute-done instruction "
      "needs to be collective-permute-start, found tuple";
  EXPECT_THAT(verdict.error_message(), HasSubstr(expected_error));
}
// Comparing f32 operands with type=UNSIGNED must be rejected; the error should
// say that FLOAT or TOTALORDER was expected.
// NOTE(review): the ENTRY name in the HLO text does not describe a comparison;
// it looks copied from another test and is left untouched here.
TEST_F(HloVerifierTest, ComparisonTypeFloat) {
  const char* const hlo_text = R"(
HloModule Module
ENTRY RngOperandElementTypesNotMatch {
p0 = f32[] parameter(0)
ROOT cmp = pred[] compare(f32[] p0, f32[] p0), direction=LT, type=UNSIGNED
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed, ParseAndReturnUnverifiedModule(hlo_text));
  const auto verdict = verifier().Run(parsed.get()).status();
  ASSERT_FALSE(verdict.ok());
  EXPECT_THAT(verdict.error_message(),
              HasSubstr("Expected comparison type FLOAT or TOTALORDER"));
}
// Comparing s32 operands with type=UNSIGNED must be rejected; the error should
// say that SIGNED was expected.
// NOTE(review): the ENTRY name in the HLO text does not describe a comparison;
// it looks copied from another test and is left untouched here.
TEST_F(HloVerifierTest, ComparisonTypeSigned) {
  static constexpr char kHloText[] = R"(
HloModule Module
ENTRY RngOperandElementTypesNotMatch {
p0 = s32[] parameter(0)
ROOT cmp = pred[] compare(s32[] p0, s32[] p0), direction=LT, type=UNSIGNED
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto outcome = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(outcome.ok());
  EXPECT_THAT(outcome.error_message(),
              HasSubstr("Expected comparison type SIGNED"));
}
// Comparing u32 operands with type=SIGNED must be rejected; the error should
// say that UNSIGNED was expected.
// NOTE(review): the ENTRY name in the HLO text does not describe a comparison;
// it looks copied from another test and is left untouched here.
TEST_F(HloVerifierTest, ComparisonTypeUnsigned) {
  const char* const hlo_text = R"(
HloModule Module
ENTRY RngOperandElementTypesNotMatch {
p0 = u32[] parameter(0)
ROOT cmp = pred[] compare(u32[] p0, u32[] p0), direction=LT, type=SIGNED
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed, ParseAndReturnUnverifiedModule(hlo_text));
  const auto verdict = verifier().Run(parsed.get()).status();
  ASSERT_FALSE(verdict.ok());
  EXPECT_THAT(verdict.error_message(),
              HasSubstr("Expected comparison type UNSIGNED"));
}
// Comparing pred operands with type=SIGNED must be rejected; per the expected
// message, the verifier wants UNSIGNED for this element type.
// NOTE(review): the ENTRY name in the HLO text does not describe a comparison;
// it looks copied from another test and is left untouched here.
TEST_F(HloVerifierTest, ComparisonTypePred) {
  static constexpr char kHloText[] = R"(
HloModule Module
ENTRY RngOperandElementTypesNotMatch {
p0 = pred[] parameter(0)
ROOT cmp = pred[] compare(pred[] p0, pred[] p0), direction=LT, type=SIGNED
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto outcome = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(outcome.ok());
  EXPECT_THAT(outcome.error_message(),
              HasSubstr("Expected comparison type UNSIGNED"));
}
// An all-reduce with channel_id=1 and use_global_device_ids=true but an empty
// replica_groups list must fail verification: the expected error says replica
// groups are required in flattened-id mode.
TEST_F(HloVerifierTest, UseGlobalDeviceIdsEmptyReplicaGroup) {
  const char* const hlo_text = R"(
HloModule Module
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY CRS {
input = f32[8]{0} parameter(0)
ROOT crs = f32[8]{0} all-reduce(input), replica_groups={}, channel_id=1,
use_global_device_ids=true, to_apply=add
})";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed, ParseAndReturnUnverifiedModule(hlo_text));
  const auto verdict = verifier().Run(parsed.get()).status();
  ASSERT_FALSE(verdict.ok());
  EXPECT_THAT(
      verdict.error_message(),
      HasSubstr("Replica groups must be specified in flattened-id mode"));
}
// An all-reduce that sets use_global_device_ids=true without any channel_id
// must fail verification with the "invalid combination" error.
TEST_F(HloVerifierTest, InvalidChannelIDandUseGlobalDeviceIDs) {
  static constexpr char kHloText[] = R"(
HloModule Module
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY CRS {
input = f32[8]{0} parameter(0)
ROOT crs = f32[8]{0} all-reduce(input), replica_groups={},
use_global_device_ids=true, to_apply=add
})";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto outcome = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(outcome.ok());
  const char* const expected_error =
      "Invalid combination of has_channel_id and use_global_device_ids";
  EXPECT_THAT(outcome.error_message(), HasSubstr(expected_error));
}
// reduce-scatter over a replica group of two keeps the output at f32[8], the
// same size as the input. The verifier should report the mismatch as
// shard_count = 1 versus subgroup_size = 2.
TEST_F(HloVerifierTest, ReduceScatterInvalidOutputSize0) {
  const char* const hlo_text = R"(
HloModule Module
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY CRS {
input = f32[8]{0} parameter(0)
ROOT crs = f32[8]{0} reduce-scatter(input), replica_groups={{0,1}},
to_apply=add, dimensions={0}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed, ParseAndReturnUnverifiedModule(hlo_text));
  const auto verdict = verifier().Run(parsed.get()).status();
  ASSERT_FALSE(verdict.ok());
  EXPECT_THAT(verdict.error_message(),
              HasSubstr("shard_count = 1, subgroup_size = 2"));
}
// reduce-scatter names dimension 1 as the scatter dimension although its
// operand is f32[8] (a single dimension). The expected error is the text of
// the verifier's scatter-dimension-vs-rank check.
TEST_F(HloVerifierTest, ReduceScatterInvalidScatterDim) {
  static constexpr char kHloText[] = R"(
HloModule Module
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY CRS {
input = f32[8]{0} parameter(0)
ROOT crs = f32[4]{0} reduce-scatter(input), replica_groups={{0,1}},
to_apply=add, dimensions={1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto outcome = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(outcome.ok());
  const char* const expected_error =
      "ars->scatter_dimension() < ars->operand(i)->shape().rank()";
  EXPECT_THAT(outcome.error_message(), HasSubstr(expected_error));
}
// reduce-scatter is given replica groups of different sizes ({0,1} and
// {2,3,4}); the verifier must reject the module and ask for uniform group
// sizes.
TEST_F(HloVerifierTest, ReduceScatterNonUniformGroups) {
  const char* const hlo_text = R"(
HloModule Module
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY CRS {
input = f32[8]{0} parameter(0)
ROOT crs = f32[4]{0} reduce-scatter(input), replica_groups={{0,1}, {2,3,4}},
to_apply=add, dimensions={0}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed, ParseAndReturnUnverifiedModule(hlo_text));
  const auto verdict = verifier().Run(parsed.get()).status();
  ASSERT_FALSE(verdict.ok());
  EXPECT_THAT(verdict.error_message(),
              HasSubstr("Replica groups expected to be of uniform size"));
}
// Positive case: a scheduled module whose loop fusion copies parameter 1 into
// a layout without the S(3) memory-space annotation and then
// dynamic-update-slices parameter 0 into it, with the fusion result declared
// with S(3) again. Verification is expected to succeed under the
// HloVerifierTestLayoutFusion fixture (defined outside this chunk).
TEST_F(HloVerifierTestLayoutFusion, DynamicUpdateSliceWithMemorySpace) {
  const char* const hlo_string = R"(
HloModule fusion, is_scheduled=true
fused_computation {
%parameter.0 = bf16[1,8,1,8,320]{4,0,3,2,1:T(2,128)(2,1)S(3)} parameter(0)
%parameter.1 = bf16[1,8,6,8,320]{4,0,3,2,1:T(2,128)(2,1)S(3)} parameter(1)
%c = bf16[1,8,6,8,320]{4,0,3,2,1:T(2,128)(2,1)} copy(parameter.1)
%constant.1 = s32[] constant(0)
ROOT %dynamic-update-slice.1 = bf16[1,8,6,8,320]{4,0,3,2,1:T(2,128)(2,1)S(3)}
dynamic-update-slice(%c, %parameter.0, %constant.1, %constant.1,
%constant.1, %constant.1, %constant.1)
}
ENTRY entry (parameter.0: bf16[1,8,1,8,320], parameter.1: bf16[1,8,6,8,320]) -> bf16[1,8,6,8,320]{
%p0 = bf16[1,8,1,8,320]{4,0,3,2,1:T(2,128)(2,1)S(3)} parameter(0)
%p1 = bf16[1,8,6,8,320]{4,0,3,2,1:T(2,128)(2,1)S(3)} parameter(1)
ROOT out = bf16[1,8,6,8,320]{4,0,3,2,1:T(2,128)(2,1)S(3)} fusion(p0, p1), kind=kLoop, calls=fused_computation
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(hlo_string));
  // TF_ASSERT_OK includes the verifier's error text in the failure output,
  // which ASSERT_TRUE(status.ok()) would discard.
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
} // namespace
} // namespace xla