/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace spmd {
namespace {
using ::testing::_;
using ::testing::AllOf;
namespace op = xla::testing::opcode_matchers;
class SpmdPartitioningTest : public HloTestBase {
public:
StatusOr<std::unique_ptr<HloModule>> PartitionComputation(
const char* hlo_module, int64 num_devices,
bool conv_halo_exchange_always_on_lhs = true) {
// Some tests (BackpropFilter convs) set this flag to false to test two
// different paths of the implementation.
SpmdPartitionerOptions options;
options.conv_halo_exchange_always_on_lhs = conv_halo_exchange_always_on_lhs;
options.allow_module_signature_change = true;
auto collective_ops_creator =
GetDefaultCollectiveOpsCreator(num_devices, /*num_replicas=*/1);
// Do not use all-gather for pattern-matching purposes, as the partitioner
// might create reshapes/transposes around it.
collective_ops_creator.create_cross_partition_all_gather = nullptr;
TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(
hlo_module, GetModuleConfigForTest()));
HloPassPipeline pass("spmd-partitioning");
pass.AddPass<HloVerifier>(/*layout_sensitive=*/false,
/*allow_mixed_precision=*/false);
pass.AddPass<SpmdPartitioner>(num_devices, /*num_replicas=*/1, options,
collective_ops_creator);
pass.AddPass<HloVerifier>(/*layout_sensitive=*/false,
/*allow_mixed_precision=*/false);
TF_RETURN_IF_ERROR(pass.Run(module.get()).status());
return StatusOr<std::unique_ptr<HloModule>>(std::move(module));
}
};
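// Sharding notation used in the tests below: {replicated} keeps a full copy
// on every partition; {maximal device=N} places the whole value on device N;
// {devices=[d0,d1,...]<ids>} tiles the value with a tile assignment of shape
// [d0,d1,...] over the listed device ids, e.g. {devices=[2,1]0,1} splits
// dimension 0 across devices 0 and 1.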
TEST_F(SpmdPartitioningTest, InvalidSharding) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
token0 = token[] after-all(), sharding={maximal device=0}
infeed = (f32[8,2]{1,0}, token[]) infeed(token0),
sharding={{devices=[2,1]0,1}, {maximal device=0}}
ROOT infeed.data = f32[8,2]{1,0} get-tuple-element(infeed), index=0,
sharding={maximal device=0}
})";
auto module_status = PartitionComputation(hlo_string, /*num_devices=*/4);
EXPECT_FALSE(module_status.status().ok());
EXPECT_THAT(module_status.status().ToString(),
::testing::HasSubstr(
"only supports tile sharding that includes all partitions"));
}
TEST_F(SpmdPartitioningTest, SingleDeviceToReplicated) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}),
sharding={maximal device=0}
ROOT %copy = s32[2,3]{1,0} copy(%constant), sharding={replicated}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
HloInstruction* root = module->entry_computation()->root_instruction();
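// Broadcasting from a single device is implemented as a select between the
// constant and zeros based on the partition-id, followed by an all-reduce
// that copies device 0's data to every partition.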
EXPECT_THAT(root, AllOf(op::Copy(op::AllReduce(
op::Select(op::Broadcast(op::Compare()),
op::Constant(), op::Broadcast()))),
op::Shape("s32[2,3]")));
}
TEST_F(SpmdPartitioningTest, SingleDeviceToSingleDevice) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}),
sharding={maximal device=0}
ROOT %copy = s32[2,3]{1,0} copy(%constant), sharding={maximal device=1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
HloInstruction* root = module->entry_computation()->root_instruction();
VLOG(1) << module->ToString();
EXPECT_THAT(root, op::Copy(AllOf(op::Copy(op::AllReduce(op::Select(
op::Broadcast(op::Compare()),
op::Constant(), op::Broadcast()))),
op::Shape("s32[2,3]"))));
}
TEST_F(SpmdPartitioningTest, SingleDeviceToTiled) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}),
sharding={maximal device=0}
ROOT %copy = s32[2,3]{1,0} copy(%constant),
sharding={devices=[2,1]1,0}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
HloInstruction* root = module->entry_computation()->root_instruction();
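// As above, device 0's data is first replicated via select + all-reduce;
// each partition then dynamic-slices out its own tile, with the slice offset
// read from a constant offset table indexed by the partition-id.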
EXPECT_THAT(
root,
AllOf(
op::Copy(op::DynamicSlice(
op::AllReduce(op::Select(
op::Broadcast(op::Compare(op::PartitionId(), op::Constant())),
op::Constant(), op::Broadcast())),
op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId(),
op::Constant())),
op::Constant())),
op::Shape("s32[1,3]")));
}
TEST_F(SpmdPartitioningTest, TiledToReplicated) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}),
sharding={devices=[2,1]0,1}
ROOT %copy = s32[2,3]{1,0} copy(%constant), sharding={replicated}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(
root,
op::Copy(op::AllReduce(AllOf(
op::DynamicUpdateSlice(
op::Broadcast(), AllOf(op::Constant(), op::Shape("s32[1,3]")),
op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId(),
op::Constant())),
op::Constant()),
op::Shape("s32[2,3]")))));
}
TEST_F(SpmdPartitioningTest, TiledToSingleDevice) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}),
sharding={devices=[2,1]0,1}
ROOT %copy = s32[2,3]{1,0} copy(%constant), sharding={maximal device=0}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(
root,
op::Copy(op::Copy(op::AllReduce(AllOf(
op::DynamicUpdateSlice(
op::Broadcast(), AllOf(op::Constant(), op::Shape("s32[1,3]")),
op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId(),
op::Constant())),
op::Constant()),
op::Shape("s32[2,3]"))))));
}
TEST_F(SpmdPartitioningTest, TiledToTiledEven) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%param = s32[8,2]{1,0} parameter(0), sharding={devices=[2,1]0,1}
ROOT %copy = s32[8,2]{1,0} copy(%param), sharding={devices=[1,2]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
HloInstruction* root = module->entry_computation()->root_instruction();
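// Resharding [2,1] -> [1,2] over the same devices needs no reduction: the
// local shard is reshaped to expose the target partitions, an all-to-all
// exchanges the data, and a transpose + reshape produce the new shard.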
EXPECT_THAT(
root,
AllOf(op::Copy(op::Reshape(op::Transpose(op::AllToAll(AllOf(
op::Reshape(op::Parameter()), op::Shape("s32[4,2,1]")))))),
op::Shape("s32[8,1]")));
}
TEST_F(SpmdPartitioningTest, TiledToTiledUneven) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%param = f32[7,31,128]{2,1,0} parameter(0), sharding={devices=[1,2,1]0,1}
ROOT %copy = f32[7,31,128]{2,1,0} copy(%param), sharding={devices=[2,1,1]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(
root,
AllOf(op::Copy(op::Slice(op::Reshape(AllOf(op::Transpose(op::AllToAll(
op::Reshape(AllOf(op::Pad(), op::Shape("f32[8,16,128]")))))))))));
}
TEST_F(SpmdPartitioningTest, GetTupleElementSwapDevice) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%param.0 = (f32[2,3]{1,0}, u32[]) parameter(0),
sharding={{maximal device=1}, {maximal device=1}}
%gte.0 = f32[2,3]{1,0} get-tuple-element(%param.0), index=0,
sharding={maximal device=0}
%gte.1 = u32[] get-tuple-element(%param.0), index=1,
sharding={maximal device=0}
ROOT %tuple = (f32[2,3]{1,0}, u32[]) tuple(%gte.0, %gte.1),
sharding={{maximal device=0},{maximal device=0}}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
HloInstruction* root = module->entry_computation()->root_instruction();
ASSERT_THAT(root, op::Tuple());
EXPECT_THAT(root->operand(0),
op::Copy(op::AllReduce(op::Select(
op::Broadcast(op::Compare(op::PartitionId(), op::Constant())),
op::GetTupleElement(op::Parameter()), op::Broadcast()))));
EXPECT_THAT(root->operand(1),
op::Copy(op::AllReduce(op::Select(
op::Broadcast(op::Compare(op::PartitionId(), op::Constant())),
op::GetTupleElement(op::Parameter()), op::Broadcast()))));
}
TEST_F(SpmdPartitioningTest, GetTupleElementTiled) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
param.0 = (f32[2,3]{1,0}, u32[2,3]{1,0}) parameter(0),
sharding={{replicated}, {replicated}}
gte.0 = f32[2,3]{1,0} get-tuple-element(param.0), index=0,
sharding={devices=[2,1]0,1}
gte.1 = u32[2,3]{1,0} get-tuple-element(param.0), index=1,
sharding={devices=[2,1]0,1}
ROOT %tuple = (f32[2,3]{1,0}, u32[2,3]{1,0}) tuple(gte.0, gte.1),
sharding={{devices=[2,1]0,1},{devices=[2,1]0,1}}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
HloInstruction* root = module->entry_computation()->root_instruction();
ASSERT_THAT(root, op::Tuple());
auto offset = op::Reshape(
op::DynamicSlice(op::Constant(), op::PartitionId(), op::Constant()));
EXPECT_THAT(root->operand(0),
op::DynamicSlice(op::GetTupleElement(op::Parameter()), offset,
op::Constant()));
EXPECT_THAT(root->operand(1),
op::DynamicSlice(op::GetTupleElement(op::Parameter()), offset,
op::Constant()));
}
TEST_F(SpmdPartitioningTest, TiledInfeed) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
token0 = token[] after-all(), sharding={maximal device=0}
infeed = (f32[8,2]{1,0}, token[]) infeed(token0),
sharding={{devices=[2,1]0,1}, {maximal device=0}}
ROOT infeed.data = f32[8,2]{1,0} get-tuple-element(infeed), index=0,
sharding={maximal device=0}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(
root, op::Copy(op::AllReduce(op::DynamicUpdateSlice(
op::Broadcast(),
op::GetTupleElement(
AllOf(op::Infeed(), op::Shape("(f32[4,2]{1,0}, token[])"))),
op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId(),
op::Constant())),
op::Constant()))));
}
TEST_F(SpmdPartitioningTest, UnevenTiledInfeed) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
token0 = token[] after-all(), sharding={maximal device=0}
infeed = (f32[9,2]{1,0}, token[]) infeed(token0),
sharding={{devices=[2,1]0,1}, {maximal device=0}}
ROOT infeed.data = f32[9,2]{1,0} get-tuple-element(infeed), index=0,
sharding={devices=[2,1]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
HloInstruction* root = module->entry_computation()->root_instruction();
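// 9 rows tiled over 2 partitions yields shard sizes 5 and 4, so the
// partitioner branches on the partition-id: one branch infeeds f32[5,2],
// the other infeeds f32[4,2] and pads it up to the uniform shard shape.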
EXPECT_THAT(
root, AllOf(op::Shape("f32[5,2]"), op::GetTupleElement(op::Conditional(
op::Convert(op::PartitionId()),
op::AfterAll(), op::AfterAll()))));
EXPECT_THAT(
root->operand(0)->called_computations()[0]->root_instruction(),
AllOf(op::Shape("(f32[5,2], token[])"), op::Infeed(op::Parameter())));
auto second_infeed =
AllOf(op::Shape("(f32[4,2], token[])"), op::Infeed(op::Parameter()));
EXPECT_THAT(root->operand(0)->called_computations()[1]->root_instruction(),
AllOf(op::Shape("(f32[5,2], token[])"),
op::Tuple(op::Pad(op::GetTupleElement(second_infeed),
op::Constant()),
op::GetTupleElement(second_infeed))));
}
TEST_F(SpmdPartitioningTest, UnevenTiledTupleInfeed) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
token0 = token[] after-all(), sharding={maximal device=0}
infeed = ((f32[9,2]{1,0}, f32[2]{0}), token[]) infeed(token0),
sharding={{devices=[2,1]0,1}, {replicated}, {maximal device=0}}
ROOT infeed.data = (f32[9,2]{1,0}, f32[2]{0}) get-tuple-element(infeed),
index=0, sharding={{devices=[2,1]0,1}, {replicated}}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::Shape("(f32[5,2], f32[2])"),
op::GetTupleElement(op::Conditional(
op::Convert(op::PartitionId()), op::AfterAll(),
op::AfterAll()))));
EXPECT_THAT(root->operand(0)->called_computations()[0]->root_instruction(),
AllOf(op::Shape("((f32[5,2], f32[2]), token[])"),
op::Infeed(op::Parameter())));
auto second_infeed = AllOf(op::Shape("((f32[4,2], f32[2]), token[])"),
op::Infeed(op::Parameter()));
EXPECT_THAT(
root->operand(0)->called_computations()[1]->root_instruction(),
AllOf(op::Shape("((f32[5,2], f32[2]), token[])"),
op::Tuple(op::Tuple(op::Pad(op::GetTupleElement(
op::GetTupleElement(second_infeed)),
op::Constant()),
op::GetTupleElement(
op::GetTupleElement(second_infeed))),
op::GetTupleElement(second_infeed))));
}
TEST_F(SpmdPartitioningTest, TiledToReplicatedReduce) {
const char* const hlo_string = R"(
HloModule module
sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add = f32[] add(a, b)
}
ENTRY entry {
constant = f32[3,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1}}),
sharding={devices=[2,1]0,1}
constant.1 = f32[] constant(0), sharding={replicated}
ROOT reduce = f32[] reduce(constant, constant.1), dimensions={0,1},
to_apply=sum, sharding={replicated}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
HloInstruction* root = module->entry_computation()->root_instruction();
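// 3 rows over 2 partitions pads the operand to 4 rows (2 per shard). Rows
// past the original extent are masked with the reduce identity (0) using an
// iota/compare/select, and an all-reduce combines the per-shard partials.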
EXPECT_THAT(
root,
op::AllReduce(op::Reduce(
op::Select(
op::Compare(op::Add(op::Iota(), op::Broadcast(op::Reshape())),
op::Broadcast(op::Constant())),
AllOf(op::Shape("f32[2,3]{1,0}"),
op::DynamicSlice(op::Pad(op::Constant(), op::Constant()),
op::Reshape(), op::Constant())),
op::Broadcast(op::Constant())),
op::Constant())));
}
TEST_F(SpmdPartitioningTest, TiledElementwise) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
constant = f32[3,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1}}),
sharding={devices=[2,1]0,1}
constant.1 = f32[3,3]{1,0} constant({{2,2,2},{2,2,2},{2,2,2}}),
sharding={replicated}
multiply = f32[3,3]{1,0} multiply(constant, constant.1),
sharding={devices=[2,1]0,1}
ROOT add = f32[3,3]{1,0} add(multiply, constant.1),
sharding={devices=[2,1]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(
root,
AllOf(
op::Shape("f32[2,3]{1,0}"),
op::Add(op::Multiply(
op::DynamicSlice(op::Pad(op::Constant(), op::Constant()),
op::Reshape(), op::Constant()),
op::DynamicSlice(op::Pad(op::Constant(), op::Constant()),
op::Reshape(), op::Constant())),
op::DynamicSlice(op::Pad(op::Constant(), op::Constant()),
op::Reshape(), op::Constant()))));
}
TEST_F(SpmdPartitioningTest, TiledAllReduce) {
const char* const hlo_string = R"(
HloModule module
sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add = f32[] add(a, b)
}
ENTRY entry {
parameter = f32[3,3]{1,0} parameter(0), sharding={devices=[2,1]0,1}
ROOT all-reduce = f32[3,3]{1,0} all-reduce(parameter), to_apply=sum,
replica_groups={}, sharding={devices=[2,1]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(
root, AllOf(op::Shape("f32[2,3]{1,0}"), op::AllReduce(op::Parameter(0))));
}
TEST_F(SpmdPartitioningTest, BroadcastOnlyNewDimsSharded) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
constant = f32[4,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1},{1,1,1}}),
sharding={replicated}
ROOT broadcast = f32[3,4,3]{2,1,0} broadcast(constant), dimensions={1,2},
sharding={devices=[2,1,1]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::Shape("f32[2,4,3]{2,1,0}"),
op::Broadcast(op::Constant())));
}
TEST_F(SpmdPartitioningTest, BroadcastOnlyOldDimsSharded) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
constant = f32[4,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1},{1,1,1}}),
sharding={replicated}
ROOT broadcast = f32[4,4,3]{2,1,0} broadcast(constant), dimensions={1,2},
sharding={devices=[1,2,1]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::Shape("f32[4,2,3]{2,1,0}"),
op::Broadcast(op::DynamicSlice(
op::Constant(), op::Reshape(), op::Constant()))));
}
TEST_F(SpmdPartitioningTest, BroadcastBothOldAndNewDimsSharded) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
constant = f32[4,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1},{1,1,1}}),
sharding={replicated}
ROOT broadcast = f32[4,4,3]{2,1,0} broadcast(constant), dimensions={1,2},
sharding={devices=[2,2,1]0,1,2,3}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/4));
VLOG(1) << module->ToString();
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(
root,
AllOf(op::Shape("f32[2,2,3]{2,1,0}"),
op::Broadcast(AllOf(op::Shape("f32[2,3]{1,0}"),
op::DynamicSlice(op::Constant(), op::Reshape(),
op::Constant())))));
}
TEST_F(SpmdPartitioningTest, BroadcastPropagateTiledSharding) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
constant = f32[4,3]{1,0} constant({{1,1,1},{1,4,1},{1,3,1},{1,2,1}}),
sharding={devices=[2,1]0,1}
ROOT broadcast = f32[4,4,3]{2,1,0} broadcast(constant), dimensions={1,2},
sharding={devices=[1,2,1]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::Shape("f32[4,2,3]{2,1,0}"),
op::Broadcast(op::DynamicSlice(
op::Constant(), op::Reshape(), op::Constant()))));
}
TEST_F(SpmdPartitioningTest, OutfeedSingleDevice) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
token.0 = token[] after-all()
data = f32[1024]{0} parameter(0), sharding={maximal device=0}
outfeed = token[] outfeed(data, token.0), sharding={maximal device=0}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
HloInstruction* root = module->entry_computation()->root_instruction();
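// An outfeed assigned to a single device runs under a conditional on the
// partition-id: the branch for device 0 performs the real outfeed, while
// the other branch only produces a token via after-all.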
EXPECT_THAT(root, AllOf(op::Shape("token[]"),
op::Conditional(
op::Compare(op::PartitionId(), op::Constant()),
op::Tuple(op::Parameter(0), op::AfterAll()),
op::Tuple(op::Parameter(0), op::AfterAll()))));
HloInstruction* root_b0 = root->branch_computation(0)->root_instruction();
EXPECT_THAT(root_b0,
AllOf(op::Shape("token[]"),
op::Outfeed(op::GetTupleElement(op::Parameter(), 0),
op::GetTupleElement(op::Parameter(), 1))));
HloInstruction* root_b1 = root->branch_computation(1)->root_instruction();
EXPECT_THAT(root_b1, AllOf(op::Shape("token[]"), op::AfterAll()));
}
TEST_F(SpmdPartitioningTest, ReduceWindowReplicatedInput) {
const char* const hlo_string = R"(
HloModule module
sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add = f32[] add(a, b)
}
ENTRY entry {
constant = f32[6,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1},{1,2},{2,2}}),
sharding={replicated}
constant.1 = f32[] constant(0), sharding={replicated}
ROOT reduce-window = f32[3,2]{1,0} reduce-window(constant, constant.1),
window={size=3x1 stride=2x1 pad=1_0x0_0}, to_apply=sum,
sharding={devices=[2,1]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(
root,
AllOf(op::Shape("f32[2,2]{1,0}"),
op::ReduceWindow(
op::DynamicSlice(AllOf(op::Shape("f32[9,2]{1,0}"),
op::Pad(op::Constant(), op::Constant())),
op::Multiply(op::Reshape(), op::Constant()),
op::Constant()),
op::Constant())));
}
TEST_F(SpmdPartitioningTest, ReduceWindowTiledNegativeLeftHalo) {
const char* const hlo_string = R"(
HloModule module
sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add = f32[] add(a, b)
}
ENTRY entry {
constant = f32[6,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1},{1,2},{2,2}}),
sharding={devices=[2,1]0,1}
constant.1 = f32[] constant(0), sharding={replicated}
ROOT %reduce-window = f32[3,2]{1,0} reduce-window(%constant, %constant.1),
window={size=3x1 stride=2x1 pad=0_1x0_0}, to_apply=sum,
sharding={devices=[2,1]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
HloInstruction* root = module->entry_computation()->root_instruction();
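// Each shard holds 3 of the 6 input rows and computes 2 of the 3 output
// rows, which requires (2 - 1) * stride(2) + window(3) = 5 input rows, so a
// 2-row right halo is exchanged with the neighbor partition. The left halo
// is negative (shard 1's first window starts past its own shard start), so
// no left halo is exchanged.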
auto sharded_input =
op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant());
auto right_halo = AllOf(op::Shape("f32[2,2]{1,0}"),
op::CollectivePermute(op::Slice(sharded_input)));
auto pre_masking = op::DynamicSlice(
AllOf(
op::Shape("f32[6,2]{1,0}"),
op::Pad(op::Concatenate(sharded_input, right_halo), op::Constant())),
op::Reshape(), op::Constant());
auto index_in_padded = op::Add(
op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant())));
auto masked =
op::Select(op::Compare(index_in_padded, op::Broadcast(op::Constant())),
pre_masking, op::Broadcast(op::Constant()));
EXPECT_THAT(root, AllOf(op::Shape("f32[2,2]{1,0}"),
op::ReduceWindow(masked, op::Constant())));
}
TEST_F(SpmdPartitioningTest, ReduceWindowTiledOneSideHaloBeyondNeighbor) {
const char* const hlo_string = R"(
HloModule module
sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add = f32[] add(a, b)
}
ENTRY entry {
param = f32[9,2] parameter(0), sharding={devices=[5,1]0,1,2,3,4}
constant.1 = f32[] constant(0), sharding={replicated}
ROOT reduce-window = f32[5,2]{1,0} reduce-window(param, constant.1),
window={size=4x1 stride=2x1 pad=3_0x0_0}, to_apply=sum,
sharding={devices=[5,1]0,1,2,3,4}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/5));
VLOG(1) << module->ToString();
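// With window 4, stride 2 and pad 3_0 over 2-row shards, a shard needs a
// 3-row left halo, which is more than any single neighbor holds; the halo
// is therefore gathered from two neighbors: 1 row from partition i-2 and
// the full 2-row shard from partition i-1.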
auto halo0 = AllOf(op::Shape("f32[1,2]"),
op::CollectivePermute(op::Slice(op::Parameter(0))));
auto halo1 =
AllOf(op::Shape("f32[2,2]"), op::CollectivePermute(op::Parameter(0)));
auto pre_mask =
AllOf(op::Shape("f32[4,2]"),
op::Slice(AllOf(op::Shape("f32[5,2]"),
op::Concatenate(halo0, halo1, op::Parameter(0)))));
auto masked =
op::Select(op::Compare(op::Add(op::Iota(), op::Broadcast(op::Multiply())),
op::Broadcast(op::Constant())),
pre_mask, op::Broadcast(op::Constant()));
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::Shape("f32[1,2]{1,0}"),
op::ReduceWindow(masked, op::Constant())));
}
TEST_F(SpmdPartitioningTest, ReduceWindowTiledOneSideUnequalHalo) {
const char* const hlo_string = R"(
HloModule module
sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add = f32[] add(a, b)
}
ENTRY entry {
constant = f32[9,2]{1,0} constant(
{{1,1},{1,4},{2,1},{3,1},{1,2},{2,2},{4,1},{1,2},{2,1}}),
sharding={devices=[3,1]0,1,2}
constant.1 = f32[] constant(0), sharding={replicated}
ROOT reduce-window = f32[5,2]{1,0} reduce-window(constant, constant.1),
window={size=3x1 stride=2x1 pad=1_1x0_0}, to_apply=sum,
sharding={devices=[3,1]0,1,2}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/3));
VLOG(1) << module->ToString();
HloInstruction* root = module->entry_computation()->root_instruction();
auto sharded_input =
op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant());
auto right_halo = AllOf(op::Shape("f32[2,2]{1,0}"),
op::CollectivePermute(op::Slice(sharded_input)));
auto pre_masking = op::DynamicSlice(
AllOf(
op::Shape("f32[7,2]{1,0}"),
op::Pad(op::Concatenate(sharded_input, right_halo), op::Constant())),
op::Reshape(), op::Constant());
auto index_in_padded = op::Add(
op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant())));
auto masked = op::Select(
op::And(op::Compare(index_in_padded, op::Broadcast(op::Constant())),
op::Compare(index_in_padded, op::Broadcast(op::Constant()))),
pre_masking, op::Broadcast(op::Constant()));
EXPECT_THAT(root, AllOf(op::Shape("f32[2,2]{1,0}"),
op::ReduceWindow(masked, op::Constant())));
}
TEST_F(SpmdPartitioningTest, ReduceWindowTiledTwoSideHalo) {
const char* const hlo_string = R"(
HloModule module
sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add = f32[] add(a, b)
}
ENTRY entry {
constant = f32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}}),
sharding={devices=[2,1]0,1}
constant.1 = f32[] constant(0), sharding={replicated}
ROOT reduce-window = f32[2,2]{1,0} reduce-window(constant, constant.1),
window={size=5x1 stride=3x1 pad=2_2x0_0}, to_apply=sum,
sharding={devices=[2,1]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
HloInstruction* root = module->entry_computation()->root_instruction();
auto sharded_input =
op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant());
auto left_halo = AllOf(op::Shape("f32[1,2]{1,0}"),
op::CollectivePermute(op::Slice(sharded_input)));
auto right_halo = AllOf(op::Shape("f32[1,2]{1,0}"),
op::CollectivePermute(op::Slice(sharded_input)));
auto pre_masking = AllOf(
op::Shape("f32[5,2]{1,0}"),
op::DynamicSlice(
AllOf(op::Shape("f32[6,2]{1,0}"),
op::Pad(op::Concatenate(left_halo, sharded_input, right_halo),
op::Constant())),
op::Reshape(), op::Constant()));
auto index_in_padded = op::Add(
op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant())));
auto masked = op::Select(
op::And(op::Compare(index_in_padded, op::Broadcast(op::Constant())),
op::Compare(index_in_padded, op::Broadcast(op::Constant()))),
pre_masking, op::Broadcast(op::Constant()));
EXPECT_THAT(root, AllOf(op::Shape("f32[1,2]{1,0}"),
op::ReduceWindow(masked, op::Constant())));
}
TEST_F(SpmdPartitioningTest, ReduceWindowTiled2D) {
const char* const hlo_string = R"(
HloModule module
sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add = f32[] add(a, b)
}
ENTRY entry {
token0 = token[] after-all(), sharding={maximal device=0}
infeed = (f32[4,4,2,2]{3,2,1,0}, token[]) infeed(token0),
sharding={{devices=[2,2,1,1]0,1,2,3}, {maximal device=0}}
infeed.data = f32[4,4,2,2]{3,2,1,0} get-tuple-element(infeed), index=0,
sharding={devices=[2,2,1,1]0,1,2,3}
constant = f32[] constant(0), sharding={replicated}
ROOT reduce-window = f32[2,2,2,2]{3,2,1,0} reduce-window(infeed.data, constant),
window={size=5x5x1x1 stride=3x3x1x1 pad=2_2x2_2x0_0x0_0}, to_apply=sum,
sharding={devices=[2,2,1,1]0,1,2,3}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/4));
VLOG(1) << module->ToString();
HloInstruction* root = module->entry_computation()->root_instruction();
auto sharded_input = AllOf(op::Shape("f32[2,2,2,2]{3,2,1,0}"),
op::GetTupleElement(op::Infeed()));
auto dim0_left_halo = AllOf(op::Shape("f32[1,2,2,2]{3,2,1,0}"),
op::CollectivePermute(op::Slice(sharded_input)));
auto dim0_right_halo = AllOf(op::Shape("f32[1,2,2,2]{3,2,1,0}"),
op::CollectivePermute(op::Slice(sharded_input)));
auto dim0_pre_masking = op::DynamicSlice(
AllOf(op::Shape("f32[6,2,2,2]{3,2,1,0}"),
op::Pad(
op::Concatenate(dim0_left_halo, sharded_input, dim0_right_halo),
op::Constant())),
op::Reshape(), op::Constant(), op::Constant(), op::Constant());
auto dim0_index_in_padded = op::Add(
op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant())));
auto dim0_masked = op::Select(
op::And(op::Compare(dim0_index_in_padded, op::Broadcast(op::Constant())),
op::Compare(dim0_index_in_padded, op::Broadcast(op::Constant()))),
dim0_pre_masking, op::Broadcast(op::Constant()));
auto dim0_resharded = AllOf(op::Shape("f32[5,2,2,2]{3,2,1,0}"), dim0_masked);
auto dim1_left_halo = AllOf(op::Shape("f32[5,1,2,2]{3,2,1,0}"),
op::CollectivePermute(op::Slice(dim0_resharded)));
auto dim1_right_halo =
AllOf(op::Shape("f32[5,1,2,2]{3,2,1,0}"),
op::CollectivePermute(op::Slice(dim0_resharded)));
auto dim1_pre_masking = op::DynamicSlice(
AllOf(op::Shape("f32[5,6,2,2]{3,2,1,0}"),
op::Pad(op::Concatenate(dim1_left_halo, dim0_resharded,
dim1_right_halo),
op::Constant())),
op::Constant(), op::Reshape(), op::Constant(), op::Constant());
auto dim1_index_in_padded = op::Add(
op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant())));
auto dim1_masked = op::Select(
op::And(op::Compare(dim1_index_in_padded, op::Broadcast(op::Constant())),
op::Compare(dim1_index_in_padded, op::Broadcast(op::Constant()))),
dim1_pre_masking, op::Broadcast(op::Constant()));
auto dim1_resharded = AllOf(op::Shape("f32[5,5,2,2]{3,2,1,0}"), dim1_masked);
EXPECT_THAT(root, AllOf(op::Shape("f32[1,1,2,2]{3,2,1,0}"),
op::ReduceWindow(dim1_resharded, op::Constant())));
}
TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsReplicated) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = f32[128,224,224,3] parameter(0)
%lhs.copy = f32[128,224,224,3] copy(f32[128,224,224,3] %lhs),
sharding={devices=[1,2,1,1]0,1}
%rhs = f32[7,7,3,64] parameter(1)
%rhs.copy = f32[7,7,3,64] copy(f32[7,7,3,64] %rhs),
sharding={replicated}
ROOT %conv = f32[128,112,112,64] convolution(
f32[128,224,224,3] %lhs.copy,
f32[7,7,3,64] %rhs.copy),
window={size=7x7 stride=2x2 pad=3_3x3_3},
dim_labels=b01f_01io->b01f,
sharding={devices=[1,2,1,1]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
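// Each shard holds 112 of the 224 input rows and computes 56 of the 112
// output rows. With window 7, stride 2 and pad 3_3, a shard's first window
// starts 3 rows before its shard and its last window ends 2 rows past it,
// hence a 3-row left halo and a 2-row right halo.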
auto lhs = AllOf(
op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
op::Constant(), op::Constant())),
op::Shape("f32[128,112,224,3]"));
auto rhs = AllOf(op::Copy(op::Parameter()), op::Shape("f32[7,7,3,64]"));
auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)),
op::Shape("f32[128,3,224,3]"));
auto right_halo = AllOf(op::CollectivePermute(op::Slice(lhs)),
op::Shape("f32[128,2,224,3]"));
EXPECT_THAT(root,
AllOf(op::Convolution(
op::Select(op::And(),
op::Concatenate(left_halo, lhs, right_halo),
op::Broadcast()),
rhs),
op::Shape("f32[128,56,112,64]")));
}
TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsReplicatedNeedReshard) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = f32[128,224,224,3] parameter(0)
%lhs.copy = f32[128,224,224,3] copy(f32[128,224,224,3] %lhs),
sharding={devices=[2,1,1,1]0,1}
%rhs = f32[7,7,3,64] parameter(1)
%rhs.copy = f32[7,7,3,64] copy(f32[7,7,3,64] %rhs),
sharding={replicated}
ROOT %conv = f32[128,112,112,64] convolution(
f32[128,224,224,3] %lhs.copy,
f32[7,7,3,64] %rhs.copy),
window={size=7x7 stride=2x2 pad=3_3x3_3},
dim_labels=b01f_01io->b01f,
sharding={devices=[1,2,1,1]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
auto lhs = AllOf(
op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(),
op::Constant(), op::Constant())),
op::Shape("f32[64,224,224,3]"));
auto all_to_all =
AllOf(op::AllToAll(op::Reshape(lhs)), op::Shape("f32[64,2,112,224,3]"));
auto reshard_lhs = AllOf(op::Reshape(op::Transpose(all_to_all)),
op::Shape("f32[128,112,224,3]"));
auto rhs = AllOf(op::Copy(op::Parameter()), op::Shape("f32[7,7,3,64]"));
auto left_halo = AllOf(op::CollectivePermute(op::Slice(reshard_lhs)),
op::Shape("f32[128,3,224,3]"));
auto right_halo = AllOf(op::CollectivePermute(op::Slice(reshard_lhs)),
op::Shape("f32[128,2,224,3]"));
EXPECT_THAT(
root,
AllOf(op::Convolution(
op::Select(op::And(),
op::Concatenate(left_halo, reshard_lhs, right_halo),
op::Broadcast()),
rhs),
op::Shape("f32[128,56,112,64]")));
}
TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsReplicatedReordered) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = f32[224,224,3,128] parameter(0)
%lhs.copy = f32[224,224,3,128] copy(%lhs), sharding={devices=[2,1,1,1]0,1}
%rhs = f32[7,7,3,64] parameter(1)
%rhs.copy = f32[7,7,3,64] copy(%rhs), sharding={replicated}
ROOT %conv = f32[128,112,112,64] convolution(%lhs.copy, %rhs.copy),
window={size=7x7 stride=2x2 pad=3_3x3_3},
dim_labels=01fb_01io->b01f,
sharding={devices=[1,2,1,1]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
auto lhs = AllOf(
op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(),
op::Constant(), op::Constant())),
op::Shape("f32[112,224,3,128]"));
auto rhs = AllOf(op::Copy(op::Parameter()), op::Shape("f32[7,7,3,64]"));
auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)),
op::Shape("f32[3,224,3,128]"));
auto right_halo = AllOf(op::CollectivePermute(op::Slice(lhs)),
op::Shape("f32[2,224,3,128]"));
EXPECT_THAT(root,
AllOf(op::Convolution(
op::Select(op::And(),
op::Concatenate(left_halo, lhs, right_halo),
op::Broadcast()),
rhs),
op::Shape("f32[128,56,112,64]")));
}
// Tests base dilation where (stride * per_shard_window_count) % dilation == 0,
// so every shard's windows start at the same offset within the dilated base.
TEST_F(SpmdPartitioningTest,
ConvolutionBaseDilationSameStartPatternLhsTiledRhsReplicated) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = f32[128,7,7,512] parameter(0)
%lhs.copy = f32[128,7,7,512] copy(%lhs),
sharding={devices=[1,2,1,1]0,1}
%rhs = f32[3,3,512,512] parameter(1)
%rhs.copy = f32[3,3,512,512] copy(%rhs),
sharding={replicated}
ROOT %conv = f32[128,4,4,512] convolution(%lhs.copy, %rhs.copy),
window={size=3x3 stride=4x4 pad=1_1x1_1 lhs_dilate=2x2 rhs_reversal=1x1},
dim_labels=b01f_01io->b01f,
sharding={devices=[1,2,1,1]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
// There is no halo exchange, and because the last element in the shard is not
// needed (stride == 4), the LHS will be just a slice.
auto sliced_lhs =
AllOf(op::Slice(op::Copy(op::DynamicSlice(
op::Pad(op::Parameter(), op::Constant()), op::Constant(),
op::Reshape(), op::Constant(), op::Constant()))),
op::Shape("f32[128,3,7,512]"));
auto rhs = AllOf(op::Copy(op::Parameter()), op::Shape("f32[3,3,512,512]"));
EXPECT_THAT(root, AllOf(op::Convolution(sliced_lhs, rhs),
op::Shape("f32[128,2,4,512]")));
EXPECT_EQ(root->window().dimensions(0).padding_low(), 1);
EXPECT_EQ(root->window().dimensions(0).padding_high(), 1);
}
// Tests base dilation where (stride * per_shard_window_count) % dilation != 0
// but stride == 1, so shards start at different offsets in the dilated base
// and the result is adjusted with a dynamic-slice on the output.
TEST_F(SpmdPartitioningTest,
ConvolutionBaseDilationStride1LhsTiledRhsReplicated) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = f32[128,7,7,512] parameter(0)
%lhs.copy = f32[128,7,7,512] copy(%lhs),
sharding={devices=[1,2,1,1]0,1}
%rhs = f32[3,3,512,512] parameter(1)
%rhs.copy = f32[3,3,512,512] copy(%rhs),
sharding={replicated}
ROOT %conv = f32[128,14,14,512] convolution(%lhs.copy, %rhs.copy),
window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1},
dim_labels=b01f_01io->b01f,
sharding={devices=[1,2,1,1]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
auto lhs = AllOf(op::Copy(op::DynamicSlice(
op::Pad(op::Parameter(), op::Constant()), op::Constant(),
op::Reshape(), op::Constant(), op::Constant())),
op::Shape("f32[128,4,7,512]"));
auto rhs = AllOf(op::Copy(op::Parameter()), op::Shape("f32[3,3,512,512]"));
auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)),
op::Shape("f32[128,1,7,512]"));
auto start_window = op::Multiply(op::Reshape(), op::Constant());
auto start_input_element = op::Divide(start_window, op::Constant());
auto dynamic_offset_for_padded_concat = op::Subtract(
op::Constant(), op::Subtract(op::Multiply(op::Reshape(), op::Constant()),
start_input_element));
auto pre_masking =
AllOf(op::Shape("f32[128,5,7,512]"),
op::DynamicSlice(
AllOf(op::Shape("f32[128,6,7,512]"),
op::Pad(op::Concatenate(left_halo, lhs), op::Constant())),
op::Constant(), dynamic_offset_for_padded_concat,
op::Constant(), op::Constant()));
auto masked = op::Select(
op::Compare(op::Add(op::Iota(), op::Broadcast(start_input_element)),
op::Broadcast(op::Constant())),
pre_masking, op::Broadcast(op::Constant()));
auto dynamic_offset_on_output = op::Subtract(
start_window, op::Multiply(start_input_element, op::Constant()));
EXPECT_THAT(root,
AllOf(op::DynamicSlice(AllOf(op::Convolution(masked, rhs),
op::Shape("f32[128,8,14,512]")),
op::Constant(), dynamic_offset_on_output,
op::Constant(), op::Constant()),
op::Shape("f32[128,7,14,512]")));
EXPECT_EQ(root->operand(0)->window().dimensions(0).padding_low(), 1);
EXPECT_EQ(root->operand(0)->window().dimensions(0).padding_high(), 0);
}
TEST_F(SpmdPartitioningTest, SelectAndScatterNoOverlap) {
const char* const hlo_string = R"(
HloModule module
ge {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT compare = pred[] compare(a, b), direction=GE
}
sum {
c = f32[] parameter(0)
d = f32[] parameter(1)
ROOT add = f32[] add(c, d)
}
ENTRY entry {
%param = f32[11,4]{1,0} parameter(0)
%param.copy = f32[11,4] copy(%param),
sharding={devices=[4,1]0,1,2,3}
constant = f32[4,2]{1,0} constant({{1,2},{3,4},{1,0},{2,8}}),
sharding={devices=[4,1]0,1,2,3}
constant.1 = f32[] constant(0), sharding={replicated}
ROOT select-and-scatter = f32[11,4]{1,0} select-and-scatter(param.copy,
constant, constant.1), window={size=3x2 stride=3x2 pad=0_1x0_0},
select=ge, scatter=sum, sharding={devices=[4,1]0,1,2,3}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/4));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
auto source =
AllOf(op::Shape("f32[1,2]{1,0}"),
op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant()));
auto masked_data = AllOf(
op::Shape("f32[3,4]{1,0}"),
op::Select(
op::Compare(op::Add(op::Iota(), op::Broadcast(op::Multiply(
op::Reshape(), op::Constant()))),
op::Broadcast(op::Constant())),
op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
op::Reshape(), op::Constant())),
op::Broadcast(op::Constant())));
EXPECT_THAT(root,
AllOf(op::SelectAndScatter(masked_data, source, op::Constant()),
op::Shape("f32[3,4]{1,0}")));
EXPECT_EQ(root->window().dimensions(0).padding_low(), 0);
EXPECT_EQ(root->window().dimensions(0).padding_high(), 0);
}
TEST_F(SpmdPartitioningTest, SelectAndScatterNoOverlapReshard) {
const char* const hlo_string = R"(
HloModule module
ge {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT compare = pred[] compare(a, b), direction=GE
}
sum {
c = f32[] parameter(0)
d = f32[] parameter(1)
ROOT add = f32[] add(c, d)
}
ENTRY entry {
%param = f32[11,4]{1,0} parameter(0)
%param.copy = f32[11,4] copy(%param),
sharding={devices=[1,4]0,1,2,3}
constant = f32[4,2]{1,0} constant({{1,2},{3,4},{1,0},{2,8}}),
sharding={devices=[4,1]0,1,2,3}
constant.1 = f32[] constant(0), sharding={replicated}
ROOT select-and-scatter = f32[11,4]{1,0} select-and-scatter(param.copy,
constant, constant.1), window={size=3x2 stride=3x2 pad=0_1x0_0},
select=ge, scatter=sum, sharding={devices=[4,1]0,1,2,3}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/4));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
auto source =
AllOf(op::Shape("f32[1,2]{1,0}"),
op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant()));
auto operand = AllOf(op::Copy(op::DynamicSlice(
op::Parameter(0), op::Constant(), op::Reshape())),
op::Shape("f32[11,1]"));
auto reshard_operand = op::Reshape(op::Transpose(
op::AllToAll(op::Reshape(op::Pad(operand, op::Constant())))));
auto masked_data = AllOf(
op::Shape("f32[3,4]{1,0}"),
op::Select(
op::Compare(op::Add(op::Iota(), op::Broadcast(op::Multiply(
op::Reshape(), op::Constant()))),
op::Broadcast(op::Constant())),
reshard_operand, op::Broadcast(op::Constant())));
EXPECT_THAT(root,
AllOf(op::SelectAndScatter(masked_data, source, op::Constant()),
op::Shape("f32[3,4]{1,0}")));
EXPECT_EQ(root->window().dimensions(0).padding_low(), 0);
EXPECT_EQ(root->window().dimensions(0).padding_high(), 0);
}
TEST_F(SpmdPartitioningTest, SelectAndScatterWithOverlap) {
const char* const hlo_string = R"(
HloModule module
ge {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT compare = pred[] compare(a, b), direction=GE
}
sum {
c = f32[] parameter(0)
d = f32[] parameter(1)
ROOT add = f32[] add(c, d)
}
ENTRY entry {
%param = f32[11,4]{1,0} parameter(0)
%param.copy = f32[11,4] copy(%param),
sharding={devices=[4,1]0,1,2,3}
constant = f32[6,2]{1,0} constant({{1,2},{3,4},{1,0},{2,8},{6,6},{1,9}}),
sharding={devices=[4,1]0,1,2,3}
constant.1 = f32[] constant(0), sharding={replicated}
ROOT select-and-scatter = f32[11,4]{1,0} select-and-scatter(param.copy,
constant, constant.1), window={size=3x2 stride=2x2 pad=1_1x0_0},
select=ge, scatter=sum, sharding={devices=[4,1]0,1,2,3}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/4));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
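// With overlapping windows (size 3, stride 2), both operands need halos:
// the source (scatter input) receives a left halo, the data receives halos
// on both sides, and since the local select-and-scatter covers a region
// larger than the shard, the result is dynamic-sliced back to shard size.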
auto source_shard =
AllOf(op::Shape("f32[2,2]{1,0}"),
op::DynamicSlice(op::Pad(), op::Reshape(), op::Constant()));
// Max halo size is the same as the shard size, so slice is not needed.
auto source_left_halo = op::CollectivePermute(source_shard);
auto required_source_shard_start =
op::Divide(op::Multiply(op::Reshape(), op::Constant()), op::Constant());
auto source_with_halo = op::DynamicSlice(
AllOf(op::Shape("f32[5,2]{1,0}"),
op::Pad(op::Concatenate(source_left_halo, source_shard),
op::Constant())),
op::Subtract(op::Constant(),
op::Subtract(op::Multiply(op::Reshape(), op::Constant()),
required_source_shard_start)),
op::Constant());
auto masked_source_with_halo = AllOf(
AllOf(op::Shape("f32[3,2]{1,0}")),
op::Select(
op::Compare(
op::Add(op::Iota(), op::Broadcast(required_source_shard_start)),
op::Broadcast(op::Constant())),
source_with_halo, op::Broadcast(op::Constant())));
auto data_shard =
AllOf(op::Shape("f32[3,4]{1,0}"),
op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
op::Reshape(), op::Constant())));
auto data_left_halo = AllOf(op::Shape("f32[2,4]{1,0}"),
op::CollectivePermute(op::Slice(data_shard)));
auto data_right_halo = AllOf(op::Shape("f32[2,4]{1,0}"),
op::CollectivePermute(op::Slice(data_shard)));
auto required_data_start_on_padded =
op::Multiply(required_source_shard_start, op::Constant());
auto left_halo_size = op::Subtract(
op::Add(op::Multiply(op::Reshape(), op::Constant()), op::Constant()),
required_data_start_on_padded);
auto data_with_halo =
AllOf(op::Shape("f32[7,4]{1,0}"),
op::DynamicSlice(
AllOf(op::Shape("f32[8,4]{1,0}"),
op::Pad(op::Concatenate(data_left_halo, data_shard,
data_right_halo),
op::Constant())),
op::Subtract(op::Constant(), left_halo_size), op::Constant()));
auto index_on_padded =
op::Add(op::Iota(), op::Broadcast(required_data_start_on_padded));
auto masked_data_with_halo = op::Select(
op::And(op::Compare(index_on_padded, op::Broadcast(op::Constant())),
op::Compare(index_on_padded, op::Broadcast(op::Constant()))),
data_with_halo, op::Broadcast(op::Constant()));
EXPECT_THAT(
root, AllOf(op::DynamicSlice(op::SelectAndScatter(masked_data_with_halo,
masked_source_with_halo,
op::Constant()),
left_halo_size, op::Constant()),
op::Shape("f32[3,4]{1,0}")));
EXPECT_EQ(root->operand(0)->window().dimensions(0).padding_low(), 0);
EXPECT_EQ(root->operand(0)->window().dimensions(0).padding_high(), 0);
}
TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiled) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = f32[128,56,56,64] parameter(0)
%lhs.copy = f32[128,56,56,64] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
%rhs = f32[128,56,56,256] parameter(1)
%rhs.copy = f32[128,56,56,256] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
ROOT %conv = f32[1,1,64,256] convolution(%lhs.copy, %rhs.copy),
window={size=56x56}, dim_labels=f01b_i01o->01bf, sharding={replicated}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
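// Both operands are tiled along a contracting spatial dimension (the window
// covers the full 56 rows), so each partition computes a partial
// f32[1,1,64,256] result from its 28-row shards and an all-reduce sums the
// partials.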
auto lhs = AllOf(
op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
op::Constant(), op::Constant())),
op::Shape("f32[128,28,56,64]"));
auto rhs = AllOf(
op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
op::Constant(), op::Constant())),
op::Shape("f32[128,28,56,256]"));
EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(lhs, rhs)),
op::Shape("f32[1,1,64,256]")));
}
TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWindowReversal) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = f32[5,128,64] parameter(0), sharding={devices=[2,1,1]0,1}
%rhs = f32[5,128,256] parameter(1), sharding={devices=[2,1,1]1,0}
ROOT %conv = f32[1,64,256] convolution(%lhs, %rhs),
window={size=5 rhs_reversal=1}, dim_labels=0fb_0io->0bf,
sharding={replicated}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto lhs_masked =
AllOf(op::Shape("f32[3,128,64]"), op::Select(_, op::Parameter(0), _));
auto rhs_left_padded = op::Slice(op::Concatenate(
op::CollectivePermute(op::Slice(op::Parameter(1))), op::Parameter(1)));
auto rhs_masked =
AllOf(op::Shape("f32[3,128,256]"), op::Select(_, rhs_left_padded, _));
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root,
AllOf(op::AllReduce(op::Convolution(lhs_masked, rhs_masked)),
op::Shape("f32[1,64,256]")));
}
TEST_F(SpmdPartitioningTest, DotLhsTiledRhsTiledWithReshard) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = f32[128,56,56,64] parameter(0)
%lhs.copy = f32[128,56,56,64] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
%rhs = f32[128,56,56,256] parameter(1)
%rhs.copy = f32[128,56,56,256] copy(%rhs), sharding={devices=[2,1,1,1]0,1}
ROOT %conv = f32[1,1,64,256] convolution(%lhs.copy, %rhs.copy),
window={size=56x56}, dim_labels=f01b_i01o->01bf, sharding={replicated}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
auto lhs = AllOf(
op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
op::Constant(), op::Constant())),
op::Shape("f32[128,28,56,64]"));
auto rhs = AllOf(
op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(),
op::Constant(), op::Constant())),
op::Shape("f32[64,56,56,256]"));
auto all_to_all =
AllOf(op::AllToAll(op::Reshape(lhs)), op::Shape("f32[2,64,28,56,64]"));
auto reshard = AllOf(op::Reshape(op::Transpose(all_to_all)));
EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(reshard, rhs)),
op::Shape("f32[1,1,64,256]")));
}
TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWithReshard) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = f32[128,56,56,512] parameter(0)
%lhs.copy = f32[128,56,56,512] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
%rhs = f32[128,28,28,64] parameter(1)
%rhs.copy = f32[128,28,28,64] copy(%rhs), sharding={devices=[2,1,1,1]0,1}
ROOT %conv = f32[1,1,512,64] convolution(%lhs.copy, %rhs.copy),
window={size=28x28 pad=0_-1x0_-1 rhs_dilate=2x2},
dim_labels=f01b_i01o->01bf, sharding={replicated}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
auto lhs = AllOf(
op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
op::Constant(), op::Constant())),
op::Shape("f32[128,28,56,512]"));
auto rhs = AllOf(
op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(),
op::Constant(), op::Constant())),
op::Shape("f32[64,28,28,64]"));
auto all_to_all =
AllOf(op::AllToAll(op::Reshape(rhs)), op::Shape("f32[64,2,14,28,64]"));
auto reshard = op::Reshape(op::Transpose(all_to_all));
EXPECT_THAT(root,
AllOf(op::AllReduce(op::Convolution(op::Slice(lhs), reshard)),
op::Shape("f32[1,1,512,64]")));
}
TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWithPadding) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = f32[32,28,28,128] parameter(0)
%lhs.copy = f32[32,28,28,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
%rhs = f32[32,28,28,64] parameter(1)
%rhs.copy = f32[32,28,28,64] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
ROOT %conv = f32[3,3,128,64] convolution(%lhs.copy, %rhs.copy),
window={size=28x28 pad=1_1x1_1}, dim_labels=f01b_i01o->01bf, sharding={replicated}
})";
TF_ASSERT_OK_AND_ASSIGN(
auto module,
PartitionComputation(hlo_string, /*num_devices=*/2,
/*conv_halo_exchange_always_on_lhs=*/false));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
auto lhs = AllOf(
op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
op::Constant(), op::Constant())),
op::Shape("f32[32,14,28,128]"));
auto rhs = AllOf(
op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
op::Constant(), op::Constant())),
op::Shape("f32[32,14,28,64]"));
auto left_halo = AllOf(op::CollectivePermute(op::Slice(rhs)),
op::Shape("f32[32,1,28,64]"));
auto right_halo = AllOf(op::CollectivePermute(op::Slice(rhs)),
op::Shape("f32[32,1,28,64]"));
EXPECT_THAT(root,
AllOf(op::AllReduce(op::Convolution(
lhs, AllOf(op::Concatenate(left_halo, rhs, right_halo),
op::Shape("f32[32,16,28,64]")))),
op::Shape("f32[3,3,128,64]")));
}
TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWindowDilate) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = f32[128,224,224,3] parameter(0)
%lhs.copy = f32[128,224,224,3] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
%rhs = f32[128,112,112,64] parameter(1)
%rhs.copy = f32[128,112,112,64] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
ROOT %conv = f32[7,7,3,64] convolution(%lhs.copy, %rhs.copy),
window={size=112x112 pad=3_2x3_2 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated}
})";
TF_ASSERT_OK_AND_ASSIGN(
auto module,
PartitionComputation(hlo_string, /*num_devices=*/2,
/*conv_halo_exchange_always_on_lhs=*/false));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
auto lhs = AllOf(
op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
op::Constant(), op::Constant())),
op::Shape("f32[128,112,224,3]"));
auto rhs = AllOf(
op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
op::Constant(), op::Constant())),
op::Shape("f32[128,56,112,64]"));
auto left_halo = AllOf(op::CollectivePermute(op::Slice(rhs)),
op::Shape("f32[128,2,112,64]"));
auto right_halo = AllOf(op::CollectivePermute(op::Slice(rhs)),
op::Shape("f32[128,2,112,64]"));
EXPECT_THAT(root,
AllOf(op::AllReduce(op::Convolution(
lhs, AllOf(op::Concatenate(left_halo, rhs, right_halo),
op::Shape("f32[128,60,112,64]")))),
op::Shape("f32[7,7,3,64]")));
}
TEST_F(SpmdPartitioningTest,
ConvolutionLhsTiledRhsTiledWindowDilateNegativeRhsPadding) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = f32[128,56,56,256] parameter(0)
%lhs.copy = f32[128,56,56,256] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
%rhs = f32[128,28,28,512] parameter(1)
%rhs.copy = f32[128,28,28,512] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
ROOT %conv = f32[1,1,256,512] convolution(%lhs.copy, %rhs.copy),
window={size=28x28 pad=0_-1x0_-1 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated}
})";
TF_ASSERT_OK_AND_ASSIGN(
auto module,
PartitionComputation(hlo_string, /*num_devices=*/2,
/*conv_halo_exchange_always_on_lhs=*/false));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
auto lhs = AllOf(
op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
op::Constant(), op::Constant())),
op::Shape("f32[128,28,56,256]"));
auto rhs = AllOf(
op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
op::Constant(), op::Constant())),
op::Shape("f32[128,14,28,512]"));
EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(lhs, rhs)),
op::Shape("f32[1,1,256,512]")));
}
TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWindowDilateUneven) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = f32[128,14,14,512] parameter(0)
%lhs.copy = f32[128,14,14,512] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
%rhs = f32[128,7,7,512] parameter(1)
%rhs.copy = f32[128,7,7,512] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
ROOT %conv = f32[3,3,512,512] convolution(%lhs.copy, %rhs.copy),
window={size=7x7 pad=1_0x1_0 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated}
})";
TF_ASSERT_OK_AND_ASSIGN(
auto module,
PartitionComputation(hlo_string, /*num_devices=*/2,
/*conv_halo_exchange_always_on_lhs=*/false));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
auto lhs = AllOf(
op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
op::Constant(), op::Constant())),
op::Shape("f32[128,7,14,512]"));
auto rhs = AllOf(
op::Select(op::Compare(),
op::Copy(op::DynamicSlice(
op::Pad(op::Parameter(), op::Constant()), op::Constant(),
op::Reshape(), op::Constant(), op::Constant())),
op::Broadcast()),
op::Shape("f32[128,4,7,512]"));
auto left_halo = AllOf(op::CollectivePermute(op::Slice(rhs)),
op::Shape("f32[128,1,7,512]"));
EXPECT_THAT(root,
AllOf(op::AllReduce(op::Convolution(
AllOf(op::DynamicSlice(op::Pad(lhs, op::Constant()),
op::Constant(), op::Subtract(),
op::Constant(), op::Constant()),
op::Shape("f32[128,10,14,512]")),
AllOf(op::Concatenate(left_halo, rhs),
op::Shape("f32[128,5,7,512]")))),
op::Shape("f32[3,3,512,512]")));
}
TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWithPadding_HaloOnLhs) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = f32[32,28,28,128] parameter(0)
%lhs.copy = f32[32,28,28,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
%rhs = f32[32,28,28,64] parameter(1)
%rhs.copy = f32[32,28,28,64] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
ROOT %conv = f32[3,3,128,64] convolution(%lhs.copy, %rhs.copy),
window={size=28x28 pad=1_1x1_1}, dim_labels=f01b_i01o->01bf, sharding={replicated}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
auto lhs = AllOf(
op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
op::Constant(), op::Constant())),
op::Shape("f32[32,14,28,128]"));
auto rhs = AllOf(
op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
op::Constant(), op::Constant())),
op::Shape("f32[32,14,28,64]"));
auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)),
op::Shape("f32[32,1,28,128]"));
auto right_halo = AllOf(op::CollectivePermute(op::Slice(lhs)),
op::Shape("f32[32,1,28,128]"));
EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(
AllOf(op::Concatenate(left_halo, lhs, right_halo),
op::Shape("f32[32,16,28,128]")),
rhs)),
op::Shape("f32[3,3,128,64]")));
}
TEST_F(SpmdPartitioningTest,
ConvolutionLhsTiledRhsTiledWindowDilate_HaloOnLhs) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = f32[128,224,224,3] parameter(0)
%lhs.copy = f32[128,224,224,3] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
%rhs = f32[128,112,112,64] parameter(1)
%rhs.copy = f32[128,112,112,64] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
ROOT %conv = f32[7,7,3,64] convolution(%lhs.copy, %rhs.copy),
window={size=112x112 pad=3_2x3_2 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
auto lhs = AllOf(
op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
op::Constant(), op::Constant())),
op::Shape("f32[128,112,224,3]"));
auto rhs = AllOf(
op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
op::Constant(), op::Constant())),
op::Shape("f32[128,56,112,64]"));
auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)),
op::Shape("f32[128,3,224,3]"));
auto right_halo = AllOf(op::CollectivePermute(op::Slice(lhs)),
op::Shape("f32[128,2,224,3]"));
EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(
AllOf(op::Concatenate(left_halo, lhs, right_halo),
op::Shape("f32[128,117,224,3]")),
rhs)),
op::Shape("f32[7,7,3,64]")));
}
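// Negative window padding can be absorbed into a local slice of the LHS
// shard, so no halo exchange is expected.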
TEST_F(SpmdPartitioningTest,
ConvolutionLhsTiledRhsTiledWindowDilateNegativeRhsPadding_HaloOnLhs) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = f32[128,56,56,256] parameter(0)
%lhs.copy = f32[128,56,56,256] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
%rhs = f32[128,28,28,512] parameter(1)
%rhs.copy = f32[128,28,28,512] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
ROOT %conv = f32[1,1,256,512] convolution(%lhs.copy, %rhs.copy),
window={size=28x28 pad=0_-1x0_-1 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
auto lhs = AllOf(
op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
op::Constant(), op::Constant())),
op::Shape("f32[128,28,56,256]"));
auto rhs = AllOf(
op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
op::Constant(), op::Constant())),
op::Shape("f32[128,14,28,512]"));
EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(op::Slice(lhs), rhs)),
op::Shape("f32[1,1,256,512]")));
}
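// Uneven halo sizes require padding the concatenated shard-plus-halo and
// then dynamic-slicing each partition's window out of the padded result.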
TEST_F(SpmdPartitioningTest,
ConvolutionLhsTiledRhsTiledWindowDilateUneven_HaloOnLhs) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = f32[128,14,14,512] parameter(0)
%lhs.copy = f32[128,14,14,512] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
%rhs = f32[128,7,7,512] parameter(1)
%rhs.copy = f32[128,7,7,512] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
ROOT %conv = f32[3,3,512,512] convolution(%lhs.copy, %rhs.copy),
window={size=7x7 pad=1_0x1_0 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
auto lhs = AllOf(
op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
op::Constant(), op::Constant())),
op::Shape("f32[128,7,14,512]"));
auto rhs = AllOf(
op::Select(op::Compare(),
op::Copy(op::DynamicSlice(
op::Pad(op::Parameter(), op::Constant()), op::Constant(),
op::Reshape(), op::Constant(), op::Constant())),
op::Broadcast()),
op::Shape("f32[128,4,7,512]"));
auto right_halo = AllOf(op::CollectivePermute(op::Slice(lhs)),
op::Shape("f32[128,1,14,512]"));
EXPECT_THAT(
root, AllOf(op::AllReduce(op::Convolution(
AllOf(op::DynamicSlice(
AllOf(op::Pad(op::Concatenate(lhs, right_halo),
op::Constant()),
op::Shape("f32[128,10,14,512]")),
op::Constant(), op::Reshape(), op::Constant(),
op::Constant()),
op::Shape("f32[128,9,14,512]")),
rhs)),
op::Shape("f32[3,3,512,512]")));
}
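// Concatenation along a dimension that is not partitioned stays local to
// each shard.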
TEST_F(SpmdPartitioningTest, ConcatenateAlongNonPartitionedDimension) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%param0 = f32[14,257] parameter(0)
%param0.copy = f32[14,257] copy(%param0), sharding={devices=[2,1]0,1}
%param1 = f32[14,116] parameter(1)
%param1.copy = f32[14,116] copy(%param1), sharding={devices=[2,1]0,1}
ROOT %concatenate = f32[14,373] concatenate(%param0.copy, %param1.copy),
dimensions={1}, sharding={devices=[2,1]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
auto param0 = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(),
op::Constant())),
op::Shape("f32[7,257]"));
auto param1 = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(),
op::Constant())),
op::Shape("f32[7,116]"));
EXPECT_THAT(root,
AllOf(op::Concatenate(param0, param1), op::Shape("f32[7,373]")));
}
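// Concatenation along the partitioned dimension rebuilds the full result:
// each partition writes its operands into a zero buffer via
// dynamic-update-slice, the buffers are combined with an all-reduce, and the
// result is dynamic-sliced back to the output sharding.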
TEST_F(SpmdPartitioningTest, ConcatenateAlongPartitionedDimension) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%param0 = f32[14,257] parameter(0)
%param0.copy = f32[14,257] copy(%param0), sharding={devices=[1,2]0,1}
%param1 = f32[14,116] parameter(1)
%param1.copy = f32[14,116] copy(%param1), sharding={devices=[1,2]0,1}
ROOT %concatenate = f32[14,373] concatenate(%param0.copy, %param1.copy),
dimensions={1}, sharding={devices=[1,2]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
auto param0 =
AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
op::Constant(), op::Reshape())),
op::Shape("f32[14,129]"));
auto param1 = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(),
op::Reshape())),
op::Shape("f32[14,58]"));
EXPECT_THAT(root, AllOf(op::DynamicSlice(
AllOf(op::AllReduce(op::DynamicUpdateSlice(
op::DynamicUpdateSlice(
op::Broadcast(), param0,
op::Constant(), op::Multiply()),
param1, op::Constant(), op::Add())),
op::Shape("f32[14,374]")),
op::Constant(), op::Multiply()),
op::Shape("f32[14,187]")));
}
TEST_F(SpmdPartitioningTest, PadAlongNonPartitionedDimension) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%param0 = f32[128,14,257] parameter(0), sharding={devices=[1,1,2]0,1}
%const = f32[] constant(0)
ROOT %pad = f32[128,17,257] pad(%param0, %const), padding=0_0x1_2x0_0,
sharding={devices=[1,1,2]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
auto param0 = AllOf(op::Parameter(), op::Shape("f32[128,14,129]"));
EXPECT_THAT(root, AllOf(op::Pad(param0, op::Constant()),
op::Shape("f32[128,17,129]")));
}
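// Padding the partitioned dimension requires a halo exchange so that the
// edge padding lands on the correct shard.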
TEST_F(SpmdPartitioningTest, PadAlongPartitionedDimension) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%param0 = f32[14,257] parameter(0), sharding={devices=[1,2]0,1}
%const = f32[] constant(0)
ROOT %pad = f32[14,259] pad(%param0, %const), padding=0_0x0_2,
sharding={devices=[1,2]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
auto param0 = AllOf(op::Parameter(), op::Shape("f32[14,129]"));
auto after_halo_exchange =
AllOf(op::Shape("f32[14,130]"),
op::Concatenate(param0, op::CollectivePermute(op::Slice(param0))));
auto pad = AllOf(op::Shape("f32[14,131]"),
op::Pad(after_halo_exchange, op::Constant()));
EXPECT_THAT(root, op::DynamicSlice(pad, op::Constant(), _));
}
TEST_F(SpmdPartitioningTest, PadAlongPartitionedDimensionWithInteriorPadding) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%param0 = f32[7] parameter(0), sharding={devices=[2]0,1}
%param1 = f32[] parameter(1), sharding={replicated}
ROOT %pad = f32[22] pad(%param0, %param1), padding=2_1_2,
sharding={devices=[2]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
auto param0 = AllOf(op::Parameter(), op::Shape("f32[4]"));
auto after_halo_exchange =
AllOf(op::Shape("f32[4]"),
op::DynamicSlice(
AllOf(op::Shape("f32[5]"),
op::Concatenate(op::CollectivePermute(op::Slice(param0)),
param0)),
_));
auto pad = op::Pad(after_halo_exchange, op::Parameter(1));
EXPECT_THAT(root, op::DynamicSlice(pad, _));
}
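// Slicing along unpartitioned dimensions is performed locally.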
TEST_F(SpmdPartitioningTest, SliceAlongNonPartitionedDimension) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%param0 = f32[128,14,257] parameter(0)
%param0.copy = f32[128,14,257] copy(%param0), sharding={devices=[1,1,2]0,1}
ROOT %slice = f32[128,11,257] slice(%param0.copy),
slice={[0:128:1], [2:13:1], [0:257:1]}, sharding={devices=[1,1,2]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
auto param0 = AllOf(
op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
op::Constant(), op::Constant(), op::Reshape())),
op::Shape("f32[128,14,129]"));
EXPECT_THAT(root, AllOf(op::Slice(param0), op::Shape("f32[128,11,129]")));
}
TEST_F(SpmdPartitioningTest, SliceAlongPartitionedDimension) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%param0 = f32[128,14,257] parameter(0)
%param0.copy = f32[128,14,257] copy(%param0), sharding={devices=[1,1,2]0,1}
ROOT %slice = f32[63,14,251] slice(%param0.copy),
slice={[2:128:2], [0:14:1], [5:256:1]}, sharding={devices=[1,1,2]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
auto param0 = AllOf(
op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
op::Constant(), op::Constant(), op::Reshape())),
op::Shape("f32[128,14,129]"));
EXPECT_THAT(
root,
AllOf(op::Slice(AllOf(
op::DynamicSlice(
AllOf(op::Concatenate(
param0,
AllOf(op::CollectivePermute(op::Slice(param0)),
op::Shape("f32[128,14,2]"))),
op::Shape("f32[128,14,131]")),
op::Constant(), op::Constant(), op::Add()),
op::Shape("f32[128,14,126]"))),
op::Shape("f32[63,14,126]")));
}
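// A sort whose sort dimension is not partitioned runs locally on each shard.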
TEST_F(SpmdPartitioningTest, SortAlongNonPartitionedDimension) {
const char* const hlo_string = R"(
HloModule module
ge {
p.0.lhs.1247 = f32[]{:T(256)} parameter(0), sharding={replicated}
bitcast-convert = s32[]{:T(256)} bitcast-convert(p.0.lhs.1247), sharding={replicated}
constant = s32[]{:T(256)} constant(0), sharding={replicated}
compare = pred[]{:T(256)E(32)} compare(bitcast-convert, constant), direction=LT, sharding={replicated}
constant.1 = u32[]{:T(256)} constant(2147483647), sharding={replicated}
bitcast-convert.1 = u32[]{:T(256)} bitcast-convert(p.0.lhs.1247), sharding={replicated}
subtract = u32[]{:T(256)} subtract(constant.1, bitcast-convert.1), sharding={replicated}
bitcast-convert.2 = s32[]{:T(256)} bitcast-convert(subtract), sharding={replicated}
select = s32[]{:T(256)} select(compare, bitcast-convert.2, bitcast-convert), sharding={replicated}
p.0.rhs.1248 = f32[]{:T(256)} parameter(1), sharding={replicated}
bitcast-convert.3 = s32[]{:T(256)} bitcast-convert(p.0.rhs.1248), sharding={replicated}
compare.1 = pred[]{:T(256)E(32)} compare(bitcast-convert.3, constant), direction=LT, sharding={replicated}
bitcast-convert.4 = u32[]{:T(256)} bitcast-convert(p.0.rhs.1248), sharding={replicated}
subtract.1 = u32[]{:T(256)} subtract(constant.1, bitcast-convert.4), sharding={replicated}
bitcast-convert.5 = s32[]{:T(256)} bitcast-convert(subtract.1), sharding={replicated}
select.1 = s32[]{:T(256)} select(compare.1, bitcast-convert.5, bitcast-convert.3), sharding={replicated}
compare.2 = pred[]{:T(256)E(32)} compare(select, select.1), direction=GT, sharding={replicated}
compare.258 = pred[]{:T(256)E(32)} compare(select.1, select), direction=GT, sharding={replicated}
compare.259 = pred[]{:T(256)E(32)} compare(compare.2, compare.258), direction=EQ, sharding={replicated}
p.1.lhs.1249 = s32[]{:T(256)} parameter(2), sharding={replicated}
p.1.rhs.1250 = s32[]{:T(256)} parameter(3), sharding={replicated}
compare.260 = pred[]{:T(256)E(32)} compare(p.1.lhs.1249, p.1.rhs.1250), direction=LT, sharding={replicated}
ROOT select.86 = pred[]{:T(256)E(32)} select(compare.259, compare.260, compare.2), sharding={replicated}
}
ENTRY entry {
%param0 = f32[128,14,257] parameter(0)
%param0.copy = f32[128,14,257] copy(%param0), sharding={devices=[1,2,1]0,1}
%param1 = s32[128,14,257] parameter(1)
%param1.copy = s32[128,14,257] copy(%param1), sharding={devices=[1,2,1]0,1}
ROOT %sort.6 = (f32[128,14,257]{2,1,0:T(8,128)}, s32[128,14,257]{2,1,0:T(8,128)})
sort(%param0.copy, %param1.copy), dimensions={2}, is_stable=true,
to_apply=%ge, sharding={{devices=[1,2,1]0,1},{devices=[1,2,1]0,1}}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
auto param0 =
AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(),
op::Reshape(), op::Constant())),
op::Shape("f32[128,7,257]"));
auto param1 =
AllOf(op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(),
op::Reshape(), op::Constant())),
op::Shape("s32[128,7,257]"));
EXPECT_THAT(root, AllOf(op::Sort(param0, param1),
op::Shape("(f32[128,7,257], s32[128,7,257])")));
}
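// The TopK custom call is partitioned along the sorted dimension: each
// partition runs TopK on its 104832-element shard, and the two partitions'
// 2000 candidates each are merged by a final 4000-element sort.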
TEST_F(SpmdPartitioningTest, PartitionCustomCall) {
const char* const hlo_string = R"(
HloModule cluster_2013453984438090939__.47
ENTRY %cluster_2013453984438090939__.47
(arg_tuple.1: ()) -> (bf16[2,2000], s32[2,2000]) {
%arg_tuple.1 = bf16[2,209664] parameter(0)
%copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1}
%custom-call = (bf16[2,2000]{1,0}, s32[2,2000]{1,0})
custom-call(bf16[2,209664]{1,0} %copy.arg_tuple.1), custom_call_target="TopK"
%get-tuple-element = bf16[2,2000]{1,0}
get-tuple-element((bf16[2,2000]{1,0}, s32[2,2000]{1,0}) %custom-call),
index=0, sharding={replicated}
%get-tuple-element.1 = s32[2,2000]{1,0} get-tuple-element((bf16[2,2000]{1,0},
s32[2,2000]{1,0}) %custom-call), index=1, sharding={replicated}
ROOT %tuple.46 = (bf16[2,2000]{1,0}, s32[2,2000]{1,0})
tuple(bf16[2,2000]{1,0} %get-tuple-element, s32[2,2000]{1,0}
%get-tuple-element.1), sharding={{replicated}, {replicated}},
metadata={op_name="XLA_Retvals"}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto custom_call = FindInstruction(module.get(), "custom-call.1");
EXPECT_EQ(custom_call->operand(0)->shape().dimensions(1), 104832);
auto sort = FindInstruction(module.get(), "sort");
EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 4000);
EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 4000);
}
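// The sort+slice TopK pattern is recognized and partitioned the same way as
// the TopK custom call: a sharded sort followed by a small merging sort.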
TEST_F(SpmdPartitioningTest, PartitionSortInTopK) {
const char* const hlo_string = R"(
HloModule module
%compare-greater-than.8 (p.0.lhs.9: bf16[], p.0.rhs.10: bf16[], p.1.lhs.11:
s32[], p.1.rhs.12: s32[]) -> pred[] {
%p.1.lhs.11 = s32[] parameter(2)
%p.1.rhs.12 = s32[] parameter(3)
%p.0.lhs.9 = bf16[] parameter(0)
%convert.13 = f32[] convert(bf16[] %p.0.lhs.9)
%bitcast-convert.16 = s32[] bitcast-convert(f32[] %convert.13)
%constant.20 = s32[] constant(0)
%compare.21 = pred[] compare(s32[] %bitcast-convert.16, s32[] %constant.20),
direction=LT
%constant.15 = u32[] constant(2147483647)
%bitcast-convert.17 = u32[] bitcast-convert(f32[] %convert.13)
%subtract.18 = u32[] subtract(u32[] %constant.15, u32[] %bitcast-convert.17)
%bitcast-convert.19 = s32[] bitcast-convert(u32[] %subtract.18)
%select.22 = s32[] select(pred[] %compare.21, s32[] %bitcast-convert.19, s32[]
%bitcast-convert.16)
%p.0.rhs.10 = bf16[] parameter(1)
%convert.14 = f32[] convert(bf16[] %p.0.rhs.10)
%bitcast-convert.24 = s32[] bitcast-convert(f32[] %convert.14)
%constant.28 = s32[] constant(0)
%compare.29 = pred[] compare(s32[] %bitcast-convert.24, s32[] %constant.28),
direction=LT
%constant.23 = u32[] constant(2147483647)
%bitcast-convert.25 = u32[] bitcast-convert(f32[] %convert.14)
%subtract.26 = u32[] subtract(u32[] %constant.23, u32[] %bitcast-convert.25)
%bitcast-convert.27 = s32[] bitcast-convert(u32[] %subtract.26)
%select.30 = s32[] select(pred[] %compare.29, s32[] %bitcast-convert.27, s32[]
%bitcast-convert.24)
ROOT %compare.31 = pred[] compare(s32[] %select.22, s32[] %select.30),
direction=GT
}
ENTRY entry
(arg_tuple.1: ()) -> (bf16[2,2000], s32[2,2000]) {
%arg_tuple.1 = bf16[2,209664] parameter(0)
%copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1}
%iota.7 = s32[2,209664] iota(), iota_dimension=1,
metadata={op_type="TopKV2" op_name="TopKV2"}
%sort.32 = (bf16[2,209664], s32[2,209664])
sort(bf16[2,209664] %copy.arg_tuple.1, s32[2,209664] %iota.7),
dimensions={1}, is_stable=true, to_apply=%compare-greater-than.8,
metadata={op_type="TopKV2" op_name="TopKV2"}
%get-tuple-element.33 = bf16[2,209664]
get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
index=0, metadata={op_type="TopKV2" op_name="TopKV2"}
%slice.34 = bf16[2,2000] slice(bf16[2,209664]
%get-tuple-element.33), slice={[0:2], [0:2000]},
metadata={op_type="TopKV2" op_name="TopKV2"}
%get-tuple-element.35 = s32[2,209664]
get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
index=1, metadata={op_type="TopKV2" op_name="TopKV2"}
%slice.36 = s32[2,2000] slice(s32[2,209664]
%get-tuple-element.35), slice={[0:2], [0:2000]},
metadata={op_type="TopKV2" op_name="TopKV2"}
ROOT %tuple.46 = (bf16[2,2000], s32[2,2000])
tuple(bf16[2,2000] %slice.34, s32[2,2000]
%slice.36), sharding={{replicated}, {replicated}},
metadata={op_name="XLA_Retvals"}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto sort = FindInstruction(module.get(), "sort");
EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 104832);
EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 104832);
auto final_sort = FindInstruction(module.get(), "sort.1");
EXPECT_EQ(final_sort->operand(0)->shape().dimensions(1), 4000);
EXPECT_EQ(final_sort->operand(1)->shape().dimensions(1), 4000);
}
TEST_F(SpmdPartitioningTest, PartitionSortInTopKWhenComparisonWithSelect) {
const char* const hlo_string = R"(
HloModule module
%compare-greater-than.8 (p.0.lhs.2566: bf16[],
p.0.rhs.2567: bf16[], p.1.lhs.2586: s32[], p.1.rhs.2587: s32[]) -> pred[] {
%p.0.lhs.2566 = bf16[] parameter(0)
%convert.164 = f32[] convert(bf16[] %p.0.lhs.2566)
%bitcast-convert.48 = s32[] bitcast-convert(f32[] %convert.164)
%constant.285 = s32[] constant(0)
%compare.84 = pred[] compare(s32[] %bitcast-convert.48, s32[] %constant.285),
direction=LT
%constant.286 = u32[] constant(2147483647)
%bitcast-convert.49 = u32[] bitcast-convert(f32[] %convert.164)
%subtract.84 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.49)
%bitcast-convert.50 = s32[] bitcast-convert(u32[] %subtract.84)
%select.40 = s32[] select(pred[] %compare.84, s32[] %bitcast-convert.50,
s32[] %bitcast-convert.48)
%p.0.rhs.2567 = bf16[] parameter(1)
%convert.165 = f32[] convert(bf16[] %p.0.rhs.2567)
%bitcast-convert.51 = s32[] bitcast-convert(f32[] %convert.165)
%compare.85 = pred[] compare(s32[] %bitcast-convert.51, s32[] %constant.285),
direction=LT
%bitcast-convert.52 = u32[] bitcast-convert(f32[] %convert.165)
%subtract.85 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.52)
%bitcast-convert.53 = s32[] bitcast-convert(u32[] %subtract.85)
%select.41 = s32[] select(pred[] %compare.85, s32[] %bitcast-convert.53,
s32[] %bitcast-convert.51)
%compare.86 = pred[] compare(s32[] %select.40, s32[] %select.41), direction=GT
%compare.1645 = pred[] compare(s32[] %select.41, s32[] %select.40), direction=GT
%compare.1646 = pred[] compare(pred[] %compare.86, pred[] %compare.1645),
direction=EQ
%p.1.lhs.2586 = s32[] parameter(2)
%p.1.rhs.2587 = s32[] parameter(3)
%compare.1647 = pred[] compare(s32[] %p.1.lhs.2586, s32[] %p.1.rhs.2587),
direction=LT
ROOT %select.1054 = pred[] select(pred[] %compare.1646, pred[] %compare.1647,
pred[] %compare.86)
}
ENTRY entry
(arg_tuple.1: ()) -> (bf16[2,2000], s32[2,2000]) {
%arg_tuple.1 = bf16[2,209664] parameter(0)
%copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1}
%iota.7 = s32[2,209664] iota(), iota_dimension=1,
metadata={op_type="TopKV2" op_name="TopKV2"}
%sort.32 = (bf16[2,209664], s32[2,209664])
sort(bf16[2,209664] %copy.arg_tuple.1, s32[2,209664] %iota.7),
dimensions={1}, is_stable=true, to_apply=%compare-greater-than.8,
metadata={op_type="TopKV2" op_name="TopKV2"}
%get-tuple-element.33 = bf16[2,209664]
get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
index=0, metadata={op_type="TopKV2" op_name="TopKV2"}
%slice.34 = bf16[2,2000] slice(bf16[2,209664]
%get-tuple-element.33), slice={[0:2], [0:2000]},
metadata={op_type="TopKV2" op_name="TopKV2"}
%get-tuple-element.35 = s32[2,209664]
get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
index=1, metadata={op_type="TopKV2" op_name="TopKV2"}
%slice.36 = s32[2,2000] slice(s32[2,209664]
%get-tuple-element.35), slice={[0:2], [0:2000]},
metadata={op_type="TopKV2" op_name="TopKV2"}
ROOT %tuple.46 = (bf16[2,2000], s32[2,2000])
tuple(bf16[2,2000] %slice.34, s32[2,2000]
%slice.36), sharding={{replicated}, {replicated}},
metadata={op_name="XLA_Retvals"}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto sort = FindInstruction(module.get(), "sort");
EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 104832);
EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 104832);
auto final_sort = FindInstruction(module.get(), "sort.1");
EXPECT_EQ(final_sort->operand(0)->shape().dimensions(1), 4000);
EXPECT_EQ(final_sort->operand(1)->shape().dimensions(1), 4000);
}
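// The TopK rewrite requires the index operand to be an iota; otherwise the
// sort is left unpartitioned along the sort dimension.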
TEST_F(SpmdPartitioningTest, NoPartitionSortInTopKWhenSecondOperandIsNotIota) {
const char* const hlo_string = R"(
HloModule module
%compare-greater-than.8 (p.0.lhs.2566: bf16[],
p.0.rhs.2567: bf16[], p.1.lhs.2586: s32[], p.1.rhs.2587: s32[]) -> pred[] {
%p.0.lhs.2566 = bf16[] parameter(0)
%convert.164 = f32[] convert(bf16[] %p.0.lhs.2566)
%bitcast-convert.48 = s32[] bitcast-convert(f32[] %convert.164)
%constant.285 = s32[] constant(0)
%compare.84 = pred[] compare(s32[] %bitcast-convert.48, s32[] %constant.285),
direction=LT
%constant.286 = u32[] constant(2147483647)
%bitcast-convert.49 = u32[] bitcast-convert(f32[] %convert.164)
%subtract.84 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.49)
%bitcast-convert.50 = s32[] bitcast-convert(u32[] %subtract.84)
%select.40 = s32[] select(pred[] %compare.84, s32[] %bitcast-convert.50,
s32[] %bitcast-convert.48)
%p.0.rhs.2567 = bf16[] parameter(1)
%convert.165 = f32[] convert(bf16[] %p.0.rhs.2567)
%bitcast-convert.51 = s32[] bitcast-convert(f32[] %convert.165)
%compare.85 = pred[] compare(s32[] %bitcast-convert.51, s32[] %constant.285),
direction=LT
%bitcast-convert.52 = u32[] bitcast-convert(f32[] %convert.165)
%subtract.85 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.52)
%bitcast-convert.53 = s32[] bitcast-convert(u32[] %subtract.85)
%select.41 = s32[] select(pred[] %compare.85, s32[] %bitcast-convert.53,
s32[] %bitcast-convert.51)
%compare.86 = pred[] compare(s32[] %select.40, s32[] %select.41), direction=GT
%compare.1645 = pred[] compare(s32[] %select.41, s32[] %select.40), direction=GT
%compare.1646 = pred[] compare(pred[] %compare.86, pred[] %compare.1645),
direction=EQ
%p.1.lhs.2586 = s32[] parameter(2)
%p.1.rhs.2587 = s32[] parameter(3)
%compare.1647 = pred[] compare(s32[] %p.1.lhs.2586, s32[] %p.1.rhs.2587),
direction=LT
ROOT %select.1054 = pred[] select(pred[] %compare.1646, pred[] %compare.1647,
pred[] %compare.86)
}
ENTRY entry {
%arg_tuple.1 = bf16[2,209664] parameter(0)
%arg_tuple.2 = s32[2,209664] parameter(1)
%copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1}
%sort.32 = (bf16[2,209664], s32[2,209664])
sort(bf16[2,209664] %copy.arg_tuple.1, s32[2,209664] %arg_tuple.2),
dimensions={1}, is_stable=true, to_apply=%compare-greater-than.8,
metadata={op_type="TopKV2" op_name="TopKV2"}
%get-tuple-element.33 = bf16[2,209664]
get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
index=0, metadata={op_type="TopKV2" op_name="TopKV2"}
%slice.34 = bf16[2,2000] slice(bf16[2,209664]
%get-tuple-element.33), slice={[0:2], [0:2000]},
metadata={op_type="TopKV2" op_name="TopKV2"}
%get-tuple-element.35 = s32[2,209664]
get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
index=1, metadata={op_type="TopKV2" op_name="TopKV2"}
%slice.36 = s32[2,2000] slice(s32[2,209664]
%get-tuple-element.35), slice={[0:2], [0:2000]},
metadata={op_type="TopKV2" op_name="TopKV2"}
ROOT %tuple.46 = (bf16[2,2000], s32[2,2000])
tuple(bf16[2,2000] %slice.34, s32[2,2000]
%slice.36), sharding={{replicated}, {replicated}},
metadata={op_name="XLA_Retvals"}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
auto sort = FindInstruction(module.get(), "sort");
EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 209664);
EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 209664);
}
TEST_F(SpmdPartitioningTest, NoPartitionSortInTopKWhenNoPartitionInSortDim) {
const char* const hlo_string = R"(
HloModule module
%compare-greater-than.8 (p.0.lhs.2566: bf16[],
p.0.rhs.2567: bf16[], p.1.lhs.2586: s32[], p.1.rhs.2587: s32[]) -> pred[] {
%p.0.lhs.2566 = bf16[] parameter(0)
%convert.164 = f32[] convert(bf16[] %p.0.lhs.2566)
%bitcast-convert.48 = s32[] bitcast-convert(f32[] %convert.164)
%constant.285 = s32[] constant(0)
%compare.84 = pred[] compare(s32[] %bitcast-convert.48, s32[] %constant.285),
direction=LT
%constant.286 = u32[] constant(2147483647)
%bitcast-convert.49 = u32[] bitcast-convert(f32[] %convert.164)
%subtract.84 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.49)
%bitcast-convert.50 = s32[] bitcast-convert(u32[] %subtract.84)
%select.40 = s32[] select(pred[] %compare.84, s32[] %bitcast-convert.50,
s32[] %bitcast-convert.48)
%p.0.rhs.2567 = bf16[] parameter(1)
%convert.165 = f32[] convert(bf16[] %p.0.rhs.2567)
%bitcast-convert.51 = s32[] bitcast-convert(f32[] %convert.165)
%compare.85 = pred[] compare(s32[] %bitcast-convert.51, s32[] %constant.285),
direction=LT
%bitcast-convert.52 = u32[] bitcast-convert(f32[] %convert.165)
%subtract.85 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.52)
%bitcast-convert.53 = s32[] bitcast-convert(u32[] %subtract.85)
%select.41 = s32[] select(pred[] %compare.85, s32[] %bitcast-convert.53,
s32[] %bitcast-convert.51)
%compare.86 = pred[] compare(s32[] %select.40, s32[] %select.41), direction=GT
%compare.1645 = pred[] compare(s32[] %select.41, s32[] %select.40), direction=GT
%compare.1646 = pred[] compare(pred[] %compare.86, pred[] %compare.1645),
direction=EQ
%p.1.lhs.2586 = s32[] parameter(2)
%p.1.rhs.2587 = s32[] parameter(3)
%compare.1647 = pred[] compare(s32[] %p.1.lhs.2586, s32[] %p.1.rhs.2587),
direction=LT
ROOT %select.1054 = pred[] select(pred[] %compare.1646, pred[] %compare.1647,
pred[] %compare.86)
}
ENTRY entry
(arg_tuple.1: ()) -> (bf16[2,2000], s32[2,2000]) {
%arg_tuple.1 = bf16[2,209664] parameter(0)
%copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[2,1]0,1}
%iota.7 = s32[2,209664] iota(), iota_dimension=1,
metadata={op_type="TopKV2" op_name="TopKV2"}
%sort.32 = (bf16[2,209664], s32[2,209664])
sort(bf16[2,209664] %copy.arg_tuple.1, s32[2,209664] %iota.7),
dimensions={1}, is_stable=true, to_apply=%compare-greater-than.8,
metadata={op_type="TopKV2" op_name="TopKV2"}
%get-tuple-element.33 = bf16[2,209664]
get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
index=0, metadata={op_type="TopKV2" op_name="TopKV2"}
%slice.34 = bf16[2,2000] slice(bf16[2,209664]
%get-tuple-element.33), slice={[0:2], [0:2000]},
metadata={op_type="TopKV2" op_name="TopKV2"}
%get-tuple-element.35 = s32[2,209664]
get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
index=1, metadata={op_type="TopKV2" op_name="TopKV2"}
%slice.36 = s32[2,2000] slice(s32[2,209664]
%get-tuple-element.35), slice={[0:2], [0:2000]},
metadata={op_type="TopKV2" op_name="TopKV2"}
ROOT %tuple.46 = (bf16[2,2000], s32[2,2000])
tuple(bf16[2,2000] %slice.34, s32[2,2000]
%slice.36), sharding={{replicated}, {replicated}},
metadata={op_name="XLA_Retvals"}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
auto sort = FindInstruction(module.get(), "sort");
EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 209664);
EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 209664);
}
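// The TopK rewrite does not apply when the slice after the sort is along a
// different dimension than the sort itself.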
TEST_F(SpmdPartitioningTest, NoPartitionSortInTopKWhenSliceInOtherDim) {
const char* const hlo_string = R"(
HloModule module
%compare-greater-than.8 (p.0.lhs.2566: bf16[],
p.0.rhs.2567: bf16[], p.1.lhs.2586: s32[], p.1.rhs.2587: s32[]) -> pred[] {
%p.0.lhs.2566 = bf16[] parameter(0)
%convert.164 = f32[] convert(bf16[] %p.0.lhs.2566)
%bitcast-convert.48 = s32[] bitcast-convert(f32[] %convert.164)
%constant.285 = s32[] constant(0)
%compare.84 = pred[] compare(s32[] %bitcast-convert.48, s32[] %constant.285),
direction=LT
%constant.286 = u32[] constant(2147483647)
%bitcast-convert.49 = u32[] bitcast-convert(f32[] %convert.164)
%subtract.84 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.49)
%bitcast-convert.50 = s32[] bitcast-convert(u32[] %subtract.84)
%select.40 = s32[] select(pred[] %compare.84, s32[] %bitcast-convert.50,
s32[] %bitcast-convert.48)
%p.0.rhs.2567 = bf16[] parameter(1)
%convert.165 = f32[] convert(bf16[] %p.0.rhs.2567)
%bitcast-convert.51 = s32[] bitcast-convert(f32[] %convert.165)
%compare.85 = pred[] compare(s32[] %bitcast-convert.51, s32[] %constant.285),
direction=LT
%bitcast-convert.52 = u32[] bitcast-convert(f32[] %convert.165)
%subtract.85 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.52)
%bitcast-convert.53 = s32[] bitcast-convert(u32[] %subtract.85)
%select.41 = s32[] select(pred[] %compare.85, s32[] %bitcast-convert.53,
s32[] %bitcast-convert.51)
%compare.86 = pred[] compare(s32[] %select.40, s32[] %select.41), direction=GT
%compare.1645 = pred[] compare(s32[] %select.41, s32[] %select.40), direction=GT
%compare.1646 = pred[] compare(pred[] %compare.86, pred[] %compare.1645),
direction=EQ
%p.1.lhs.2586 = s32[] parameter(2)
%p.1.rhs.2587 = s32[] parameter(3)
%compare.1647 = pred[] compare(s32[] %p.1.lhs.2586, s32[] %p.1.rhs.2587),
direction=LT
ROOT %select.1054 = pred[] select(pred[] %compare.1646, pred[] %compare.1647,
pred[] %compare.86)
}
ENTRY entry {
%arg_tuple.1 = bf16[2,209664] parameter(0)
%copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1}
%iota.7 = s32[2,209664] iota(), iota_dimension=1,
metadata={op_type="TopKV2" op_name="TopKV2"}
%sort.32 = (bf16[2,209664], s32[2,209664])
sort(bf16[2,209664] %copy.arg_tuple.1, s32[2,209664] %iota.7),
dimensions={1}, is_stable=true, to_apply=%compare-greater-than.8,
metadata={op_type="TopKV2" op_name="TopKV2"}
%get-tuple-element.33 = bf16[2,209664]
get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
index=0, metadata={op_type="TopKV2" op_name="TopKV2"}
%slice.34 = bf16[1,209664] slice(bf16[2,209664]
%get-tuple-element.33), slice={[0:1], [0:209664]},
metadata={op_type="TopKV2" op_name="TopKV2"}
%get-tuple-element.35 = s32[2,209664]
get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
index=1, metadata={op_type="TopKV2" op_name="TopKV2"}
%slice.36 = s32[1,209664] slice(s32[2,209664]
%get-tuple-element.35), slice={[0:1], [0:209664]},
metadata={op_type="TopKV2" op_name="TopKV2"}
ROOT %tuple.46 = (bf16[1,209664], s32[1,209664])
tuple(bf16[1,209664] %slice.34, s32[1,209664]
%slice.36), sharding={{replicated}, {replicated}},
metadata={op_name="XLA_Retvals"}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto sort = FindInstruction(module.get(), "sort");
EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 209664);
EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 209664);
}
TEST_F(SpmdPartitioningTest, ShardableTranspose) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%param0 = f32[16,38,38,4] parameter(0)
%param0.copy = f32[16,38,38,4] copy(%param0), sharding={devices=[1,2,1,1]0,1}
ROOT %transpose = f32[16,4,38,38] transpose(%param0.copy),
dimensions={0,3,1,2}, sharding={devices=[1,1,2,1]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
auto param0 = AllOf(
op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
op::Constant(), op::Constant())),
op::Shape("f32[16,19,38,4]"));
EXPECT_THAT(root, AllOf(op::Transpose(param0), op::Shape("f32[16,4,19,38]")));
}
TEST_F(SpmdPartitioningTest, MultiDimensionShardedTranspose) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%param0 = f32[16,38,38,4] parameter(0)
%param0.copy = f32[16,38,38,4] copy(%param0),
sharding={devices=[4,2,1,1]0,1,2,3,4,5,6,7}
ROOT %transpose = f32[38,4,16,38] transpose(%param0.copy),
dimensions={1,3,0,2}, sharding={devices=[2,1,4,1]0,2,4,6,1,3,5,7}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/8));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
auto param0 = AllOf(
op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Reshape(),
op::Constant(), op::Constant())),
op::Shape("f32[4,19,38,4]"));
EXPECT_THAT(root, AllOf(op::Transpose(param0), op::Shape("f32[19,4,4,38]")));
}
TEST_F(SpmdPartitioningTest, NonShardableTranspose) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%param0 = f32[16,38,38,4] parameter(0)
%param0.copy = f32[16,38,38,4] copy(%param0), sharding={devices=[1,2,1,1]0,1}
ROOT %transpose = f32[16,4,38,38] transpose(%param0.copy),
dimensions={0,3,1,2}, sharding={devices=[1,2,1,1]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
  auto reshard = AllOf(op::Reshape(op::Transpose(op::Reshape(op::AllToAll()))),
                       op::Shape("f32[16,38,38,2]"));
  EXPECT_THAT(root,
              AllOf(op::Transpose(reshard), op::Shape("f32[16,2,38,38]")));
}
TEST_F(SpmdPartitioningTest, ShardableReshape) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%param0 = f32[38,38,324] parameter(0)
%param0.copy = f32[38,38,324] copy(%param0), sharding={devices=[2,1,1]0,1}
ROOT %reshape = f32[38,38,4,81] reshape(%param0.copy),
sharding={devices=[2,1,1,1]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
auto param0 =
AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(),
op::Constant(), op::Constant())),
op::Shape("f32[19,38,324]"));
EXPECT_THAT(root, AllOf(op::Reshape(param0), op::Shape("f32[19,38,4,81]")));
}
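// A reshape that cannot keep data local falls back to reshaping a replicated
// value (formed by an all-reduce) and then re-tiling the output with pad and
// dynamic-slice.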
TEST_F(SpmdPartitioningTest, NonShardableReshape) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%param0 = f32[38,38,324] parameter(0)
%param0.copy = f32[38,38,324] copy(%param0), sharding={devices=[1,1,2]0,1}
ROOT %transpose = f32[38,38,4,81] reshape(%param0.copy),
sharding={devices=[1,1,1,2]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(
root,
AllOf(op::DynamicSlice(
AllOf(op::Pad(
AllOf(op::Reshape(AllOf(op::AllReduce(),
op::Shape("f32[38,38,324]"))),
op::Shape("f32[38,38,4,81]")),
op::Constant()),
op::Shape("f32[38,38,4,82]")),
op::Constant(), op::Constant(), op::Constant(), op::Reshape()),
op::Shape("f32[38,38,4,41]")));
}
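// Merging a sharded dimension in a reshape can be handled with a halo
// exchange instead of full replication.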
TEST_F(SpmdPartitioningTest, ReshapeMergeDimsWithHaloExchange) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%input = s32[2,3,7,10] parameter(0), sharding={devices=[1,1,2,1]0,1}
ROOT %reshape = s32[3,2,1,14,5] reshape(%input),
sharding={devices=[1,1,1,2,1]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto reshape =
AllOf(op::Reshape(op::Parameter(0)), op::Shape("s32[3,2,1,8,5]"));
auto halo = op::CollectivePermute(op::Slice(reshape));
auto exchanged =
op::DynamicSlice(op::Concatenate(halo, reshape), _, _, _, _, _);
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(exchanged, op::Shape("s32[3,2,1,7,5]")));
}
// Produces an invalid module after transformation.
TEST_F(SpmdPartitioningTest, InceptionV3_4_way_ReduceWindowDilated) {
const char* const hlo_string = R"(
HloModule module
sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add = f32[] add(a, b)
}
ENTRY entry {
%param0 = f32[128,5,5,768] parameter(0)
%param0.copy = f32[128,5,5,768] copy(%param0),
sharding={devices=[1,4,1,1]0,1,2,3}
%constant.1 = f32[] constant(0), sharding={replicated}
ROOT %rw = f32[128,17,17,768] reduce-window(%param0.copy, %constant.1),
window={size=1x5x5x1 pad=0_0x4_4x4_4x0_0 lhs_dilate=1x3x3x1},
to_apply=sum, sharding={devices=[1,4,1,1]0,1,2,3}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/4));
VLOG(1) << module->ToString();
auto input_shard = op::Copy(op::DynamicSlice(
op::Pad(op::Parameter(0), op::Constant()), op::Constant(), op::Reshape(),
op::Constant(), op::Constant()));
auto id_mul4_add1 =
op::Add(op::Multiply(op::Reshape(), op::Constant()), op::Constant());
auto id_mul5 = op::Multiply(op::Reshape(), op::Constant());
auto id_mul5_add1_div3 =
op::Divide(op::Add(id_mul5, op::Constant()), op::Constant());
auto before_masking = AllOf(
op::Shape("f32[128,3,5,768]"),
op::DynamicSlice(
AllOf(
op::Shape("f32[128,4,5,768]"),
op::Concatenate(op::CollectivePermute(input_shard), input_shard)),
op::Constant(),
op::Subtract(op::Constant(),
op::Subtract(id_mul4_add1, id_mul5_add1_div3)),
op::Constant(), op::Constant()));
auto masked = op::Select(
op::And(op::Compare(op::Add(op::Iota(), op::Broadcast(id_mul5_add1_div3)),
op::Broadcast(op::Constant())),
op::Compare(op::Add(op::Iota(), op::Broadcast(id_mul5_add1_div3)),
op::Broadcast(op::Constant()))),
before_masking, op::Broadcast(op::Constant()));
auto rw = AllOf(op::Shape("f32[128,7,17,768]"),
op::ReduceWindow(masked, op::Constant()));
auto final_slice_index = op::Subtract(
id_mul5,
op::Add(op::Multiply(id_mul5_add1_div3, op::Constant()), op::Constant()));
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root,
AllOf(op::Shape("f32[128,5,17,768]"),
op::DynamicSlice(rw, op::Constant(), final_slice_index,
op::Constant(), op::Constant())));
}
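// Reducing away unpartitioned dimensions needs no cross-partition
// communication when the kept dimension's sharding is unchanged.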
TEST_F(SpmdPartitioningTest, TiledToTiledReduce) {
const char* const hlo_string = R"(
HloModule module
sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add = f32[] add(a, b)
}
ENTRY entry {
%param0 = f32[4,32,32,128] parameter(0)
%param0.copy = f32[4,32,32,128] copy(%param0),
sharding={devices=[1,1,1,2]0,1}
%constant.1 = f32[] constant(0), sharding={replicated}
%reduce = f32[128] reduce(%param0.copy, %constant.1), dimensions={0,1,2},
to_apply=%sum, sharding={devices=[2]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
auto param0 = AllOf(
op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
op::Constant(), op::Reshape())),
op::Shape("f32[4,32,32,64]"));
EXPECT_THAT(root,
AllOf(op::Reduce(param0, op::Constant()), op::Shape("f32[64]")));
}
TEST_F(SpmdPartitioningTest, TiledToTiledTupleReduce) {
const char* const hlo_string = R"(
HloModule module
%minmax_func {
%lhs_value = f32[] parameter(0)
%rhs_value = f32[] parameter(2)
%compare.2 = pred[] compare(%lhs_value, %rhs_value), direction=GT
%select.4 = f32[] select(%compare.2, %lhs_value, %rhs_value)
%lhs_index = s32[] parameter(1)
%rhs_index = s32[] parameter(3)
%select.5 = s32[] select(%compare.2, %lhs_index, %rhs_index)
ROOT %tuple.2 = (f32[], s32[]) tuple(%select.4, %select.5)
}
ENTRY %main {
%param0 = f32[28,10] parameter(0), sharding={devices=[2,1]0,1}
%param1 = s32[28,10] parameter(1), sharding={devices=[2,1]0,1}
%init0 = f32[] parameter(2)
%init1 = s32[] parameter(3)
ROOT %reduce = (f32[28], s32[28]) reduce(%param0, %param1, %init0, %init1),
dimensions={1}, to_apply=%minmax_func,
sharding={{devices=[2]0,1}, {devices=[2]0,1}}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::Reduce(op::Parameter(0), op::Parameter(1),
op::Parameter(2), op::Parameter(3)),
op::Shape("(f32[14], s32[14])")));
}
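// When a reduced dimension is partitioned, the local reduce is followed by
// an all-reduce, and the replicated result is dynamic-sliced to the
// requested output sharding.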
TEST_F(SpmdPartitioningTest, TiledToTiledReduceOutputReshard) {
const char* const hlo_string = R"(
HloModule module
sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add = f32[] add(a, b)
}
ENTRY entry {
%param0 = f32[4,32,32,128] parameter(0)
%param0.copy = f32[4,32,32,128] copy(%param0),
sharding={devices=[1,2,1,1]0,1}
%constant.1 = f32[] constant(0), sharding={replicated}
%reduce = f32[128] reduce(%param0.copy, %constant.1), dimensions={0,1,2},
to_apply=%sum, sharding={devices=[2]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
auto param0 = AllOf(
op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
op::Constant(), op::Constant())),
op::Shape("f32[4,16,32,128]"));
EXPECT_THAT(root,
AllOf(op::DynamicSlice(
AllOf(op::AllReduce(op::Reduce(param0, op::Constant())),
op::Shape("f32[128]")),
op::Reshape()),
op::Shape("f32[64]")));
}
TEST_F(SpmdPartitioningTest, IotaAlongNonTileDimension) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
ROOT %iota = s32[16,80,91] iota(), iota_dimension=1,
sharding={devices=[1,1,2]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::Iota(), op::Shape("s32[16,80,46]")));
}
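// Iota along the tiled dimension is computed as a local iota plus a
// broadcast per-partition offset.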
TEST_F(SpmdPartitioningTest, IotaAlongTileDimension) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
ROOT %iota = s32[16,80,91] iota(), iota_dimension=2,
sharding={devices=[1,1,2]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::Add(op::Iota(), op::Broadcast()),
op::Shape("s32[16,80,46]")));
}
TEST_F(SpmdPartitioningTest, U32IotaAlongTileDimension) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
ROOT %iota = u32[16,80,91] iota(), iota_dimension=2,
sharding={devices=[1,1,2]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::Add(op::Iota(), op::Broadcast()),
op::Shape("u32[16,80,46]")));
}
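// Conditional branches are partitioned to the root sharding; the
// maximal-device predicate is made replicated via an all-reduce before
// branching.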
TEST_F(SpmdPartitioningTest, Conditional) {
const char* const hlo_string = R"(
HloModule module
Negate {
x = f32[4,5] parameter(0), sharding={replicated}
ROOT negate = f32[4,5] negate(x), sharding={replicated}
}
Identity {
y = f32[4,5] parameter(0), sharding={devices=[2,1]0,1}
ROOT copy = f32[4,5] copy(y), sharding={devices=[2,1]0,1}
}
ENTRY entry {
%param.0 = pred[] parameter(0)
%param.0.copy = pred[] copy(%param.0), sharding={maximal device=0}
%param.1 = f32[4,5] parameter(1)
%param.1.copy = f32[4,5] copy(%param.1), sharding={replicated}
%param.2 = f32[4,5] parameter(2)
%param.2.copy = f32[4,5] copy(%param.2), sharding={devices=[2,1]0,1}
ROOT cond = f32[4,5] conditional(%param.0.copy, %param.1.copy, %param.2.copy),
true_computation=Negate, false_computation=Identity,
sharding={devices=[2,1]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
  auto param0 =
      AllOf(op::Copy(op::Copy(op::Parameter())), op::Shape("pred[]"));
auto param1 = AllOf(op::Copy(op::Parameter()), op::Shape("f32[4,5]"));
auto param2 = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(),
op::Constant())),
op::Shape("f32[2,5]"));
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::Conditional(op::AllReduce(), param1, param2),
op::Shape("f32[2,5]")));
auto then_branch_root = root->branch_computation(0)->root_instruction();
EXPECT_THAT(then_branch_root,
AllOf(op::DynamicSlice(op::Negate(op::Parameter()), op::Reshape(),
op::Constant()),
op::Shape("f32[2,5]")));
auto else_branch_root = root->branch_computation(1)->root_instruction();
EXPECT_THAT(else_branch_root,
AllOf(op::Copy(op::Parameter()), op::Shape("f32[2,5]")));
}
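// Select-and-scatter whose windows align with the partition boundaries needs
// no halo: each shard processes its own slice with zero window padding.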
TEST_F(SpmdPartitioningTest, SelectAndScatter_RetinaNet) {
const char* const hlo_string = R"(
HloModule module
ge {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT compare = pred[] compare(a, b), direction=GE
}
sum {
c = f32[] parameter(0)
d = f32[] parameter(1)
ROOT add = f32[] add(c, d)
}
ENTRY entry {
%param.0 = f32[32,128,384,64] parameter(0)
%param.0.copy = f32[32,128,384,64] copy(%param.0),
sharding={devices=[1,8,1,1]0,1,2,3,4,5,6,7}
%param.1 = f32[32,64,192,64] parameter(1)
%param.1.copy = f32[32,64,192,64] copy(%param.1),
sharding={devices=[1,8,1,1]0,1,2,3,4,5,6,7}
constant.1 = f32[] constant(0), sharding={replicated}
ROOT select-and-scatter = f32[32,128,384,64] select-and-scatter(param.0.copy,
%param.1.copy, constant.1), window={size=1x1x1x1 stride=1x2x2x1},
select=ge, scatter=sum, sharding={devices=[1,8,1,1]0,1,2,3,4,5,6,7}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/8));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
auto source = AllOf(
op::Shape("f32[32,8,192,64]"),
op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(), op::Reshape(),
op::Constant(), op::Constant())));
auto data = AllOf(
op::Shape("f32[32,16,384,64]"),
op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
op::Constant(), op::Constant())));
EXPECT_THAT(root, op::SelectAndScatter(data, source, op::Constant()));
EXPECT_EQ(root->window().dimensions(0).padding_low(), 0);
EXPECT_EQ(root->window().dimensions(0).padding_high(), 0);
}
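// A dot with the contracting dimension partitioned produces a partial result
// on each partition, combined with an all-reduce.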
TEST_F(SpmdPartitioningTest, TiledDot) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = f32[128,64] parameter(0)
%lhs.copy = f32[128,64] copy(%lhs), sharding={devices=[1,2]0,1}
%rhs = f32[64,256] parameter(1)
%rhs.copy = f32[64,256] copy(%rhs), sharding={devices=[2,1]0,1}
ROOT %conv = f32[128,256] convolution(%lhs.copy, %rhs.copy),
dim_labels=bf_io->bf, sharding={replicated}
})";
TF_ASSERT_OK_AND_ASSIGN(
auto module,
PartitionComputation(hlo_string, /*num_devices=*/2,
/*conv_halo_exchange_always_on_lhs=*/false));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
auto lhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(),
op::Reshape())),
op::Shape("f32[128,32]"));
auto rhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(),
op::Constant())),
op::Shape("f32[32,256]"));
EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(lhs, rhs)),
op::Shape("f32[128,256]")));
}
TEST_F(SpmdPartitioningTest, TiledDotOutputTiled) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = f32[128,64] parameter(0)
%lhs.copy = f32[128,64] copy(%lhs), sharding={devices=[1,2]0,1}
%rhs = f32[64,256] parameter(1)
%rhs.copy = f32[64,256] copy(%rhs), sharding={devices=[2,1]0,1}
ROOT %conv = f32[128,256] convolution(%lhs.copy, %rhs.copy),
dim_labels=bf_io->bf, sharding={devices=[1,2]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
auto lhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(),
op::Reshape())),
op::Shape("f32[128,32]"));
auto rhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(),
op::Constant())),
op::Shape("f32[32,256]"));
EXPECT_THAT(root, AllOf(op::DynamicSlice(
AllOf(op::AllReduce(op::Convolution(lhs, rhs)),
op::Shape("f32[128,256]")),
op::Constant(), op::Reshape()),
op::Shape("f32[128,128]")));
}
TEST_F(SpmdPartitioningTest, BatchPartitionedConvolution) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = f32[128,256,256] parameter(0)
%lhs.copy = f32[128,256,256] copy(%lhs), sharding={devices=[1,2,1]0,1}
%rhs = f32[256,8,1] parameter(1)
%rhs.copy = f32[256,8,1] copy(%rhs), sharding={replicated}
ROOT %conv = f32[128,256,8] convolution(%lhs.copy, %rhs.copy),
window={size=1}, dim_labels=0bf_io0->0bf, sharding={devices=[1,2,1]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
auto lhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(),
op::Reshape(), op::Constant())),
op::Shape("f32[128,128,256]"));
auto rhs = AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[256,8,1]"));
EXPECT_THAT(root,
AllOf(op::Convolution(lhs, rhs), op::Shape("f32[128,128,8]")));
}
TEST_F(SpmdPartitioningTest, DotOutputFeaturePartitioned) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = f32[24,64] parameter(0)
%lhs.copy = f32[24,64] copy(%lhs), sharding={replicated}
%rhs = f32[39296,64] parameter(1)
%rhs.copy = f32[39296,64] copy(%rhs), sharding={devices=[2,1]0,1}
ROOT %dot = f32[24,39296] dot(%lhs.copy, %rhs.copy),
lhs_batch_dims={}, rhs_batch_dims={},
lhs_contracting_dims={1}, rhs_contracting_dims={1},
sharding={devices=[1,2]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
auto lhs = AllOf(op::Copy(op::Parameter(0)), op::Shape("f32[24,64]"));
auto rhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(1), op::Reshape(),
op::Constant())),
op::Shape("f32[19648,64]"));
EXPECT_THAT(root, AllOf(op::Dot(lhs, rhs), op::Shape("f32[24,19648]")));
}
TEST_F(SpmdPartitioningTest, EinsumBatchPartitioned) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = f32[32,24,64] parameter(0)
%lhs.copy = f32[32,24,64] copy(%lhs), sharding={devices=[2,1,1]0,1}
%rhs = f32[32,39296,64] parameter(1)
%rhs.copy = f32[32,39296,64] copy(%rhs), sharding={devices=[2,1,1]0,1}
ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
lhs_batch_dims={0}, rhs_batch_dims={0},
lhs_contracting_dims={2}, rhs_contracting_dims={2},
sharding={devices=[2,1,1]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
auto lhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
op::Constant(), op::Constant())),
op::Shape("f32[16,24,64]"));
auto rhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(1), op::Reshape(),
op::Constant(), op::Constant())),
op::Shape("f32[16,39296,64]"));
EXPECT_THAT(root, AllOf(op::Dot(lhs, rhs), op::Shape("f32[16,24,39296]")));
}
TEST_F(SpmdPartitioningTest, EinsumLHSandOutputBatchPartitioned) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = f32[32,24,64] parameter(0)
%lhs.copy = f32[32,24,64] copy(%lhs), sharding={devices=[2,1,1]0,1}
%rhs = f32[32,39296,64] parameter(1)
%rhs.copy = f32[32,39296,64] copy(%rhs), sharding={replicated}
ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
lhs_batch_dims={0}, rhs_batch_dims={0},
lhs_contracting_dims={2}, rhs_contracting_dims={2},
sharding={devices=[2,1,1]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
auto lhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
op::Constant(), op::Constant())),
op::Shape("f32[16,24,64]"));
auto rhs = AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[32,39296,64]"));
EXPECT_THAT(root, AllOf(op::Dot(lhs, op::DynamicSlice(rhs, op::Reshape(),
op::Constant(),
op::Constant())),
op::Shape("f32[16,24,39296]")));
}
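// The LHS is resharded from the non-contracting dimension to the batch
// dimension via an all-to-all so the dot can run batch-partitioned.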
TEST_F(SpmdPartitioningTest, EinsumRHSandOutputBatchPartitioned) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = f32[32,24,64] parameter(0)
%lhs.copy = f32[32,24,64] copy(%lhs), sharding={devices=[1,2,1]0,1}
%rhs = f32[32,39296,64] parameter(1)
%rhs.copy = f32[32,39296,64] copy(%rhs), sharding={devices=[2,1,1]0,1}
ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
lhs_batch_dims={0}, rhs_batch_dims={0},
lhs_contracting_dims={2}, rhs_contracting_dims={2},
sharding={devices=[2,1,1]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
auto lhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(),
op::Reshape(), op::Constant())),
op::Shape("f32[32,12,64]"));
auto rhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(1), op::Reshape(),
op::Constant(), op::Constant())),
op::Shape("f32[16,39296,64]"));
auto lhs_reshard = op::Reshape(op::Transpose(op::AllToAll(op::Reshape(lhs))));
EXPECT_THAT(root,
AllOf(op::Dot(lhs_reshard, rhs), op::Shape("f32[16,24,39296]")));
}
TEST_F(SpmdPartitioningTest, EinsumOutputBatchPartitioned) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = f32[32,24,64] parameter(0)
%lhs.copy = f32[32,24,64] copy(%lhs), sharding={replicated}
%rhs = f32[32,39296,64] parameter(1)
%rhs.copy = f32[32,39296,64] copy(%rhs), sharding={replicated}
ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
lhs_batch_dims={0}, rhs_batch_dims={0},
lhs_contracting_dims={2}, rhs_contracting_dims={2},
sharding={devices=[2,1,1]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
auto lhs_slice =
AllOf(op::DynamicSlice(op::Copy(op::Parameter(0)), op::Reshape(),
op::Constant(), op::Constant()),
op::Shape("f32[16,24,64]"));
auto rhs_slice =
AllOf(op::DynamicSlice(op::Copy(op::Parameter(1)), op::Reshape(),
op::Constant(), op::Constant()),
op::Shape("f32[16,39296,64]"));
EXPECT_THAT(root, AllOf(op::Dot(lhs_slice, rhs_slice),
op::Shape("f32[16,24,39296]")));
}
TEST_F(SpmdPartitioningTest, EinsumContractingDimsPartitioned) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = f32[32,24,64,128] parameter(0)
%lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,1,2,2]0,1,2,3}
%rhs = f32[32,39296,64,128] parameter(1)
%rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={devices=[1,1,2,2]0,1,2,3}
ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
lhs_batch_dims={0}, rhs_batch_dims={0},
lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
sharding={replicated}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/4));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
auto lhs = AllOf(
op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(),
op::Constant(), op::Reshape(), op::Reshape())),
op::Shape("f32[32,24,32,64]"));
auto rhs = AllOf(
op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(),
op::Constant(), op::Reshape(), op::Reshape())),
op::Shape("f32[32,39296,32,64]"));
EXPECT_THAT(root, AllOf(op::AllReduce(op::Dot(lhs, rhs)),
op::Shape("f32[32,24,39296]")));
}
TEST_F(SpmdPartitioningTest, EinsumLHSNonContractingDimsPartitioned) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = f32[32,24,64,128] parameter(0)
%lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,2]0,1,2,3}
%rhs = f32[32,39296,64] parameter(1)
%rhs.copy = f32[32,39296,64] copy(%rhs), sharding={replicated}
ROOT %dot = f32[32,24,128,39296] dot(%lhs.copy, %rhs.copy),
lhs_batch_dims={0}, rhs_batch_dims={0},
lhs_contracting_dims={2}, rhs_contracting_dims={2},
sharding={devices=[1,2,2,1]0,1,2,3}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/4));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
auto lhs = AllOf(
op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
op::Constant(), op::Reshape())),
op::Shape("f32[32,12,64,64]"));
auto rhs = AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[32,39296,64]"));
EXPECT_THAT(root, AllOf(op::Dot(lhs, rhs), op::Shape("f32[32,12,64,39296]")));
}
TEST_F(SpmdPartitioningTest, EinsumRHSNonContractingDimsPartitioned) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = f32[32,24,64] parameter(0)
%lhs.copy = f32[32,24,64] copy(%lhs), sharding={replicated}
%rhs = f32[32,39296,64,128] parameter(1)
%rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={devices=[1,2,1,2]0,1,2,3}
ROOT %dot = f32[32,24,39296,128] dot(%lhs.copy, %rhs.copy),
lhs_batch_dims={0}, rhs_batch_dims={0},
lhs_contracting_dims={2}, rhs_contracting_dims={2},
sharding={devices=[1,1,2,2]0,1,2,3}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/4));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
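// Mirror of the previous test: the RHS non-contracting dimensions are tiled,
// so the local dot against the replicated LHS directly produces the
// [1,1,2,2]-sharded output.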
auto lhs = AllOf(op::Copy(op::Parameter(0)), op::Shape("f32[32,24,64]"));
auto rhs = AllOf(
op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(), op::Reshape(),
op::Constant(), op::Reshape())),
op::Shape("f32[32,19648,64,64]"));
EXPECT_THAT(root, AllOf(op::Dot(lhs, rhs), op::Shape("f32[32,24,19648,64]")));
}
TEST_F(SpmdPartitioningTest, EinsumOutputLHSNonContractingDimPartitioned) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = f32[32,24,64,128] parameter(0)
%lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={replicated}
%rhs = f32[32,39296,64,128] parameter(1)
%rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={replicated}
ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
lhs_batch_dims={0}, rhs_batch_dims={0},
lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
sharding={devices=[1,2,1]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
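// Both operands are replicated but the output is tiled on the LHS
// non-contracting dimension, so the partitioner dynamic-slices the LHS
// before the dot rather than slicing the full result afterwards.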
auto lhs = AllOf(op::Copy(op::Parameter(0)), op::Shape("f32[32,24,64,128]"));
auto rhs =
AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[32,39296,64,128]"));
EXPECT_THAT(
root,
AllOf(op::Dot(AllOf(op::DynamicSlice(lhs, op::Constant(), op::Reshape(),
op::Constant(), op::Constant()),
op::Shape("f32[32,12,64,128]")),
rhs),
op::Shape("f32[32,12,39296]")));
}
TEST_F(SpmdPartitioningTest, EinsumOutputRHSNonContractingDimPartitioned) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = f32[32,24,64,128] parameter(0)
%lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={replicated}
%rhs = f32[32,39296,64,128] parameter(1)
%rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={replicated}
ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
lhs_batch_dims={0}, rhs_batch_dims={0},
lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
sharding={devices=[1,1,2]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
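// Same as above, but with the output tiled on the RHS non-contracting
// dimension, so the RHS is dynamic-sliced before the dot.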
auto lhs = AllOf(op::Copy(op::Parameter(0)), op::Shape("f32[32,24,64,128]"));
auto rhs =
AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[32,39296,64,128]"));
EXPECT_THAT(root,
AllOf(op::Dot(lhs, AllOf(op::DynamicSlice(
rhs, op::Constant(), op::Reshape(),
op::Constant(), op::Constant()),
op::Shape("f32[32,19648,64,128]"))),
op::Shape("f32[32,24,19648]")));
}
TEST_F(SpmdPartitioningTest, EinsumRHSWindowedNonContracting) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = f32[32,24,64,128] parameter(0)
%lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
%rhs = f32[32,39295,64,128] parameter(1)
%rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
ROOT %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy),
lhs_batch_dims={0}, rhs_batch_dims={0},
lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
sharding={devices=[1,2,1]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string,
/*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
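// Windowed einsum: the RHS shard is rotated through the partitions by a
// collective-permute inside a while loop, and each iteration writes its
// partial result into the output with a dynamic-update-slice. The padded
// dimension (39295 -> 39296) is sliced back off at the end.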
auto lhs = AllOf(
op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
op::Constant(), op::Constant())),
op::Shape("f32[32,12,64,128]"));
auto rhs =
AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(1), op::Constant()),
op::Constant(), op::Reshape(),
op::Constant(), op::Constant())),
op::Shape("f32[32,19648,64,128]"));
EXPECT_THAT(
root,
AllOf(op::Slice(AllOf(op::GetTupleElement(op::While(op::Tuple(
lhs, rhs, op::Broadcast(), op::Constant()))),
op::Shape("f32[32,12,39296]"))),
op::Shape("f32[32,12,39295]")));
auto while_loop = root->operand(0)->operand(0);
// Check loop condition.
EXPECT_THAT(
while_loop->while_condition()->root_instruction(),
op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
// Check loop body.
auto next_i = op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant());
auto window = op::Conditional(op::Compare(next_i, op::Constant()),
op::GetTupleElement(op::Parameter(0)),
op::GetTupleElement(op::Parameter(0)));
auto partial_output = op::Dot(op::GetTupleElement(op::Parameter(0)),
op::GetTupleElement(op::Parameter(0)));
EXPECT_THAT(
while_loop->while_body()->root_instruction(),
op::Tuple(op::GetTupleElement(op::Parameter(0)), window,
op::DynamicUpdateSlice(op::GetTupleElement(op::Parameter(0)),
partial_output, op::Constant(),
op::Constant(), op::Reshape()),
next_i));
// Check the conditional that contains the collective permute.
auto cp_conditional =
while_loop->while_body()->root_instruction()->operand(1);
EXPECT_THAT(cp_conditional->true_computation()->root_instruction(),
op::CollectivePermute(op::Parameter(0)));
EXPECT_THAT(cp_conditional->false_computation()->root_instruction(),
op::Parameter(0));
}
TEST_F(SpmdPartitioningTest, EinsumRHSWindowedContracting) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = f32[32,24,63,128] parameter(0)
%lhs.copy = f32[32,24,63,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
%rhs = f32[32,39296,63,128] parameter(1)
%rhs.copy = f32[32,39296,63,128] copy(%rhs), sharding={devices=[1,1,2,1]0,1}
ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
lhs_batch_dims={0}, rhs_batch_dims={0},
lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
sharding={devices=[1,2,1]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string,
/*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
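// Here the windowed (rotated) operand carries a contracting dimension, so
// each iteration's partial dot is accumulated into the result with an add
// instead of a dynamic-update-slice; the padded region of the RHS (63 -> 64)
// is masked to zero so it does not contribute to the sum.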
auto lhs = AllOf(
op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
op::Constant(), op::Constant())),
op::Shape("f32[32,12,63,128]"));
auto rhs =
AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(1), op::Constant()),
op::Constant(), op::Constant(),
op::Reshape(), op::Constant())),
op::Shape("f32[32,39296,32,128]"));
auto masked_rhs =
op::Select(op::Compare(), rhs, op::Broadcast(op::Constant()));
EXPECT_THAT(root,
AllOf(op::GetTupleElement(op::While(op::Tuple(
lhs, masked_rhs, op::Broadcast(), op::Constant()))),
op::Shape("f32[32,12,39296]")));
auto while_loop = root->operand(0);
// Check loop condition.
EXPECT_THAT(
while_loop->while_condition()->root_instruction(),
op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
// Check loop body.
auto next_i = op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant());
auto window = op::Conditional(op::Compare(next_i, op::Constant()),
op::GetTupleElement(op::Parameter(0)),
op::GetTupleElement(op::Parameter(0)));
auto partial_output = op::Dot(
op::DynamicSlice(
op::Pad(op::GetTupleElement(op::Parameter(0)), op::Constant()),
op::Constant(), op::Constant(), op::Reshape(), op::Constant()),
op::GetTupleElement(op::Parameter(0)));
EXPECT_THAT(
while_loop->while_body()->root_instruction(),
op::Tuple(op::GetTupleElement(op::Parameter(0)), window,
op::Add(op::GetTupleElement(op::Parameter(0)), partial_output),
next_i));
// Check the conditional that contains the collective permute.
auto cp_conditional =
while_loop->while_body()->root_instruction()->operand(1);
EXPECT_THAT(cp_conditional->true_computation()->root_instruction(),
op::CollectivePermute(op::Parameter(0)));
EXPECT_THAT(cp_conditional->false_computation()->root_instruction(),
op::Parameter(0));
}
TEST_F(SpmdPartitioningTest, EinsumRHSWindowedNonContractingReduce1) {
const char* const hlo_string = R"(
HloModule module
sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add = f32[] add(a, b)
}
ENTRY entry {
%lhs = f32[32,24,64,128] parameter(0)
%lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
%rhs = f32[32,39295,64,128] parameter(1)
%rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
%dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy),
lhs_batch_dims={0}, rhs_batch_dims={0},
lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
sharding={devices=[1,2,1]0,1}
%constant = f32[] constant(0)
%constant.1 = f32[] constant(2)
%broadcast = f32[32,24,39295] broadcast(%constant.1), dimensions={},
sharding={devices=[1,2,1]0,1}
%multiply = f32[32,24,39295] multiply(%dot, %broadcast),
sharding={devices=[1,2,1]0,1}
ROOT %reduce = f32[32,24] reduce(%multiply, %constant), dimensions={2},
to_apply=sum, sharding={devices=[1,2]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string,
/*num_devices=*/2));
VLOG(1) << module->ToString();
// This case involves loop code motion, so the test only verifies that
// partitioning succeeds and skips pattern matching on the result.
}
TEST_F(SpmdPartitioningTest, EinsumRHSWindowedNonContractingReduce2) {
const char* const hlo_string = R"(
HloModule module
sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add = f32[] add(a, b)
}
ENTRY entry {
%lhs = f32[32,24,64,128] parameter(0)
%lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
%rhs = f32[32,39295,64,128] parameter(1)
%rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
%dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy),
lhs_batch_dims={0}, rhs_batch_dims={0},
lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
sharding={devices=[1,2,1]0,1}
%constant = f32[] constant(0)
%constant.1 = f32[] constant(2)
%broadcast = f32[32,24,39295] broadcast(%constant.1), dimensions={},
sharding={devices=[1,2,1]0,1}
%multiply = f32[32,24,39295] multiply(%dot, %broadcast),
sharding={devices=[1,2,1]0,1}
ROOT %reduce = f32[32,39295] reduce(%multiply, %constant), dimensions={1},
to_apply=sum, sharding={replicated}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string,
/*num_devices=*/2));
VLOG(1) << module->ToString();
// This case involves loop code motion, so the test only verifies that
// partitioning succeeds and skips pattern matching on the result.
}
TEST_F(SpmdPartitioningTest, EinsumRHSWindowedContractingFromBroadcast) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%rhs = f32[32,39296,63,128] parameter(0)
%rhs.copy = f32[32,39296,63,128] copy(%rhs), sharding={devices=[1,1,2,1]0,1}
%constant.1 = f32[] constant(2)
%broadcast = f32[32,24,63,128] broadcast(%constant.1), dimensions={},
sharding={devices=[1,2,1,1]0,1}
%add = f32[32,24,63,128] add(%broadcast, %broadcast),
sharding={devices=[1,2,1,1]0,1}
ROOT %dot = f32[32,24,39296] dot(%add, %rhs.copy),
lhs_batch_dims={0}, rhs_batch_dims={0},
lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
sharding={devices=[1,2,1]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string,
/*num_devices=*/2));
VLOG(1) << module->ToString();
// This case involves loop code motion, so the test only verifies that
// partitioning succeeds and skips pattern matching on the result.
}
TEST_F(SpmdPartitioningTest, ReplicatedRng) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = s32[] parameter(0)
%lhs.copy = s32[] copy(%lhs), sharding={replicated}
%rhs = s32[] parameter(1)
%rhs.copy = s32[] copy(%rhs), sharding={replicated}
ROOT %rng = s32[4]{0} rng(%lhs.copy, %rhs.copy),
distribution=rng_uniform, sharding={replicated}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
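// To keep the replicated result identical on all partitions, the rng runs
// on a single partition (selected by comparing the partition-id against a
// constant); the other partitions contribute zeros, and an all-reduce
// broadcasts the generated values.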
auto lhs = AllOf(op::Copy(op::Parameter(0)), op::Shape("s32[]"));
auto rhs = AllOf(op::Copy(op::Parameter(1)), op::Shape("s32[]"));
EXPECT_THAT(
root,
AllOf(op::AllReduce(op::Select(
op::Broadcast(op::Compare(op::PartitionId(), op::Constant())),
op::Rng(), op::Broadcast(op::Constant()))),
op::Shape("s32[4]")));
}
TEST_F(SpmdPartitioningTest, PartitionedRng) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = s32[] parameter(0)
%lhs.copy = s32[] copy(%lhs), sharding={replicated}
%rhs = s32[] parameter(1)
%rhs.copy = s32[] copy(%rhs), sharding={maximal device=1}
ROOT %rng = s32[4]{0} rng(%lhs.copy, %rhs.copy),
distribution=rng_uniform, sharding={devices=[2]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
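// The rhs operand lives only on device 1, so it is first broadcast to all
// partitions via select + all-reduce; each partition then generates its own
// s32[2] shard of the rng output.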
auto lhs = AllOf(op::Copy(op::Parameter(0)), op::Shape("s32[]"));
auto rhs = AllOf(op::Copy(op::Copy(op::Parameter(1))), op::Shape("s32[]"));
EXPECT_THAT(root, AllOf(op::Rng(lhs, op::AllReduce(op::Select(
op::Broadcast(op::Compare()), rhs,
op::Broadcast(op::Constant())))),
op::Shape("s32[2]")));
}
TEST_F(SpmdPartitioningTest, DynamicSliceAlongNonPartitionedDimension) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%input = s32[128,64] parameter(0)
%input.copy = s32[128,64] copy(%input), sharding={devices=[2,1]0,1}
%index = s32[] parameter(1)
%constant = s32[] constant(0)
ROOT %dynamic-slice = s32[128,2] dynamic-slice(%input.copy, %constant, %index),
dynamic_slice_sizes={128,2}, sharding={devices=[2,1]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
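// The dynamic offset only touches the non-partitioned dimension, so the
// dynamic-slice applies directly to the local [64,64] shard.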
auto input = AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
op::Constant())),
op::Shape("s32[64,64]"));
EXPECT_THAT(root,
AllOf(op::DynamicSlice(input, op::Constant(), op::Parameter(1)),
op::Shape("s32[64,2]")));
}
TEST_F(SpmdPartitioningTest, DynamicUpdateSliceAlongNonPartitionedDimension) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%input = s32[128,64] parameter(0)
%input.copy = s32[128,64] copy(%input), sharding={devices=[2,1]0,1}
%index = s32[] parameter(1)
%constant = s32[] constant(0)
%update = s32[128,2] parameter(2)
%update.copy = s32[128,2] copy(%update), sharding={devices=[2,1]0,1}
ROOT %dynamic-update-slice = s32[128,64]
dynamic-update-slice(%input.copy, %update.copy, %constant, %index),
sharding={devices=[2,1]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
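// As in the dynamic-slice test above, the dynamic offset only touches the
// non-partitioned dimension, so the update applies to the local shards.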
auto input = AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
op::Constant())),
op::Shape("s32[64,64]"));
auto update = AllOf(op::Copy(op::DynamicSlice(op::Parameter(2), op::Reshape(),
op::Constant())),
op::Shape("s32[64,2]"));
EXPECT_THAT(root, AllOf(op::DynamicUpdateSlice(input, update, op::Constant(),
op::Parameter(1)),
op::Shape("s32[64,64]")));
}
TEST_F(SpmdPartitioningTest, PassthroughGather) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%input = f32[2,9] parameter(0), sharding={devices=[1,2]0,1}
%indices = s32[3] parameter(1), sharding={replicated}
ROOT %gather = f32[3,9] gather(%input, %indices), offset_dims={1},
collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1,
slice_sizes={1,9}, sharding={devices=[1,2]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
HloInstruction* root = module->entry_computation()->root_instruction();
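// The sharded dimension is a pass-through (offset) dimension of the gather,
// so each partition gathers directly from its local [2,5] shard.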
EXPECT_THAT(root, AllOf(op::Gather(op::Parameter(0), op::Parameter(1)),
op::Shape("f32[3,5]")));
}
TEST_F(SpmdPartitioningTest, GatherPartitionedOnTrivialSliceDims) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%input = f32[17,9] parameter(0), sharding={devices=[2,1]0,1}
%indices = s32[2,3] parameter(1), sharding={replicated}
ROOT %gather = f32[2,3,9] gather(%input, %indices), offset_dims={2},
collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2,
slice_sizes={1,9}, sharding={replicated}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
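// The operand is partitioned on a trivially-sliced dimension
// (slice_sizes={1,9}), so each partition clamps the indices to its local
// range, gathers locally, masks out-of-range rows with zeros, and combines
// the per-partition results with an all-reduce.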
auto offset = op::Reshape(
op::DynamicSlice(op::Constant(), op::PartitionId(), op::Constant()));
auto min = AllOf(op::Broadcast(offset), op::Shape("s32[2,3]"));
auto max = AllOf(op::Broadcast(op::Add(offset, op::Constant())),
op::Shape("s32[2,3]"));
auto clamp = op::Clamp(min, op::Parameter(1), max);
auto gather = op::Gather(op::Parameter(0), op::Subtract(clamp, min));
auto mask =
op::Or(op::Lt(op::Parameter(1), min), op::Gt(op::Parameter(1), max));
auto masked =
op::Select(op::Broadcast(mask), op::Broadcast(op::Constant()), gather);
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::AllReduce(masked), op::Shape("f32[2,3,9]")));
}
TEST_F(SpmdPartitioningTest, PassthroughScatter) {
const char* const hlo_string = R"(
HloModule module
add (lhs: f32[], rhs: f32[]) -> f32[] {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT sum = f32[] add(lhs, rhs)
}
ENTRY entry {
%input = f32[2,9] parameter(0), sharding={devices=[1,2]0,1}
%indices = s32[3] parameter(1), sharding={replicated}
%updates = f32[3,9] parameter(2), sharding={devices=[1,2]0,1}
ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates),
to_apply=add,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1, sharding={devices=[1,2]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
HloInstruction* root = module->entry_computation()->root_instruction();
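// As with the pass-through gather, the sharded dimension appears only in the
// update window, so the scatter applies directly to the local shards.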
EXPECT_THAT(root, AllOf(op::Scatter(op::Parameter(0), op::Parameter(1),
op::Parameter(2)),
op::Shape("f32[2,5]")));
}
TEST_F(SpmdPartitioningTest, ScatterPartitionedOnTrivialSliceDims) {
const char* const hlo_string = R"(
HloModule module
add (lhs: f32[], rhs: f32[]) -> f32[] {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT sum = f32[] add(lhs, rhs)
}
ENTRY entry {
%input = f32[17,9] parameter(0), sharding={devices=[2,1]0,1}
%indices = s32[2,3] parameter(1), sharding={replicated}
%updates = f32[2,3,9] parameter(2), sharding={replicated}
ROOT %scatter = f32[17,9] scatter(%input, %indices, %updates),
to_apply=add,
update_window_dims={2},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=2, sharding={devices=[2,1]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
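// Each partition shifts the indices by its row offset (looked up from a
// constant table by partition-id) and scatters into its local [9,9] shard.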
auto offset = op::Reshape(
op::DynamicSlice(op::Constant(), op::PartitionId(), op::Constant()));
auto indices = op::Subtract(
op::Parameter(1), AllOf(op::Broadcast(offset), op::Shape("s32[2,3]")));
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root,
AllOf(op::Scatter(op::Parameter(0), indices, op::Parameter(2)),
op::Shape("f32[9,9]")));
}
TEST_F(SpmdPartitioningTest, TiledReversePassthrough) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
constant = f32[3,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1}}),
sharding={devices=[2,1]0,1}
ROOT reverse = f32[3,3]{1,0} reverse(constant), dimensions={1},
sharding={devices=[2,1]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
HloInstruction* root = module->entry_computation()->root_instruction();
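// The reversed dimension is not partitioned, so the reverse applies locally
// to the padded-and-sliced [2,3] shard.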
EXPECT_THAT(root, AllOf(op::Shape("f32[2,3]{1,0}"),
op::Reverse(op::DynamicSlice(
op::Pad(op::Constant(), op::Constant()),
op::Reshape(), op::Constant()))));
}
TEST_F(SpmdPartitioningTest, TiledReversePassthroughViaReversedSharding) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
param = f32[4] parameter(0), sharding={devices=[2]0,1}
ROOT reverse = f32[4] reverse(param), dimensions={0},
sharding={devices=[2]1,0}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
HloInstruction* root = module->entry_computation()->root_instruction();
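// The output sharding reverses the device order, so each partition already
// holds exactly the data it needs and the reverse is purely local.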
EXPECT_THAT(root, AllOf(op::Shape("f32[2]"), op::Reverse(op::Parameter(0))));
}
TEST_F(SpmdPartitioningTest, TiledReverseSwapShards) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
param = f32[4] parameter(0), sharding={devices=[2]0,1}
ROOT reverse = f32[4] reverse(param), dimensions={0},
sharding={devices=[2]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
HloInstruction* root = module->entry_computation()->root_instruction();
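// Input and output use the same device order, so the shards must swap
// places across partitions via a collective-permute before the local
// reverse.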
EXPECT_THAT(root,
AllOf(op::Shape("f32[2]"),
op::Reverse(op::CollectivePermute(op::Parameter(0)))));
}
TEST_F(SpmdPartitioningTest, TiledReverseHaloExchange) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
param = f32[3] parameter(0), sharding={devices=[2]0,1}
ROOT reverse = f32[3] reverse(param), dimensions={0},
sharding={devices=[2]1,0}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
HloInstruction* root = module->entry_computation()->root_instruction();
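// f32[3] tiles unevenly across 2 partitions (shard size 2, with padding), so
// a halo exchange (slice + collective-permute + concatenate + slice)
// realigns the data before the local reverse.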
auto halo_exchange_concat =
op::Concatenate(AllOf(op::Shape("f32[1]"),
op::CollectivePermute(op::Slice(op::Parameter(0)))),
op::Parameter(0));
auto after_halo_exchange = op::Slice(halo_exchange_concat);
EXPECT_THAT(root,
AllOf(op::Shape("f32[2]"), op::Reverse(after_halo_exchange)));
}
TEST_F(SpmdPartitioningTest, MixWithManualPartitioning) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
param = f32[8,2] parameter(0), sharding={devices=[2,1]0,1}
to_shard = f32[4,2] custom-call(param), custom_call_target="SPMDFullToShardShape", sharding={replicated}
add = f32[4,2] add(to_shard, to_shard), sharding={replicated}
to_full = f32[8,2] custom-call(add), custom_call_target="SPMDShardToFullShape", sharding={devices=[2,1]0,1}
ROOT mul = f32[8,2] multiply(to_full, param), sharding={devices=[2,1]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
VLOG(1) << module->ToString();
HloInstruction* root = module->entry_computation()->root_instruction();
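// SPMDFullToShardShape/SPMDShardToFullShape lower to plain copies; the add
// in between operates on the manually partitioned [4,2] shards.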
auto to_shard = op::Copy(op::Parameter(0));
EXPECT_THAT(root, AllOf(op::Shape("f32[4,2]"),
op::Multiply(op::Copy(op::Add(to_shard, to_shard)),
op::Parameter(0))));
}
TEST_F(SpmdPartitioningTest, SubgroupAllToAllReshard) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%param0 = f32[8,8,8,8] parameter(0),
sharding={devices=[2,2,1,2]0,1,2,3,4,5,6,7}
ROOT %copy = f32[8,8,8,8] copy(%param0),
sharding={devices=[1,2,2,2]0,1,4,5,2,3,6,7}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/8));
VLOG(1) << module->ToString();
auto root = module->entry_computation()->root_instruction();
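// Resharding [2,2,1,2] -> [1,2,2,2] with a regrouped device order is done
// with a subgrouped all-to-all: reshape to split the exchanged dimension,
// all-to-all across 4 groups of 2 devices, then transpose and reshape back.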
auto reshape =
AllOf(op::Shape("f32[4,4,2,4,4]"), op::Reshape(op::Parameter(0)));
auto all_to_all = AllOf(op::Shape("f32[4,4,2,4,4]"), op::AllToAll(reshape));
auto xpose = AllOf(op::Shape("f32[2,4,4,4,4]"), op::Transpose(all_to_all));
EXPECT_THAT(root,
op::Copy(AllOf(op::Reshape(xpose), op::Shape("f32[8,4,4,4]"))));
EXPECT_EQ(root->operand(0)->operand(0)->operand(0)->replica_groups().size(),
4);
}
} // namespace
} // namespace spmd
} // namespace xla