blob: f3f0711d4e16c6abb05328927f869e0f5dbaed6e [file] [log] [blame]
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/nvptx_compiler.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/util.h"
namespace xla {
namespace gpu {
using NVPTXCompilerTest = HloTestBase;
TEST_F(NVPTXCompilerTest, AllReducePerformedInplace) {
const absl::string_view hlo_string = R"(
HloModule Module
summit {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY entry {
param0 = f32[128] parameter(0)
param1 = f32[128] parameter(1)
add = f32[128] add(param0, param1)
ROOT allreduce = f32[128] all-reduce(add), replica_groups={}, to_apply=summit
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string));
NVPTXCompiler compiler;
Compiler::CompileOptions compile_options;
TF_ASSERT_OK_AND_ASSIGN(auto module_and_buffer_assignment,
compiler.RunHloPassesAndBufferAssignement(
std::move(module),
/*executor=*/nullptr,
/*optimize=*/false, compile_options));
module = std::move(std::get<0>(module_and_buffer_assignment));
std::unique_ptr<BufferAssignment> buffer_assignment =
std::move(std::get<1>(module_and_buffer_assignment));
HloInstruction* all_reduce = module->entry_computation()->root_instruction();
ASSERT_EQ(
buffer_assignment->GetInstructionAllocation(all_reduce, {}),
buffer_assignment->GetInstructionAllocation(all_reduce->operand(0), {}));
}
TEST_F(NVPTXCompilerTest, AllReducePerformedInplaceTwoOperands) {
const absl::string_view hlo_string = R"(
HloModule Module
summit {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY entry {
param0 = f32[128] parameter(0)
param1 = f32[128] parameter(1)
add = f32[128] add(param0, param1)
sub = f32[128] subtract(param0, param1)
ROOT allreduce = (f32[128], f32[128]) all-reduce(add, sub),
replica_groups={}, to_apply=summit
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string));
NVPTXCompiler compiler;
Compiler::CompileOptions compile_options;
TF_ASSERT_OK_AND_ASSIGN(auto module_and_buffer_assignment,
compiler.RunHloPassesAndBufferAssignement(
std::move(module),
/*executor=*/nullptr,
/*optimize=*/false, compile_options));
module = std::move(std::get<0>(module_and_buffer_assignment));
std::unique_ptr<BufferAssignment> buffer_assignment =
std::move(std::get<1>(module_and_buffer_assignment));
HloInstruction* all_reduce = module->entry_computation()->root_instruction();
ASSERT_EQ(
buffer_assignment->GetInstructionAllocation(all_reduce, {0}),
buffer_assignment->GetInstructionAllocation(all_reduce->operand(0), {}));
ASSERT_EQ(
buffer_assignment->GetInstructionAllocation(all_reduce, {1}),
buffer_assignment->GetInstructionAllocation(all_reduce->operand(1), {}));
}
} // namespace gpu
} // namespace xla