// blob: df1ba164befe21112fa3b2aa36cd60a49c725e99 [file] [log] [blame]
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/cublas_gemm_pad_for_tensor_cores.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/util.h"
// Short alias for the HLO opcode matcher namespace; the tests below refer to
// its matchers as op::Shape, op::Dot, op::Pad, op::Slice, and so on.
namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace gpu {
namespace {
// Test fixture for the CublasGemmPadForTensorCores pass. It adds no members of
// its own; the tests only use what HloTestBase provides (for example
// ParseAndReturnVerifiedModule).
class CublasGemmPadForTensorCoresTest : public HloTestBase {};
// A single f16 dot whose output dimension 33708 is not a multiple of 8. The
// pass is expected to report a change and to leave, as the new root, a slice
// back to the original shape of a dot computed over padded operands.
TEST_F(CublasGemmPadForTensorCoresTest, OneDotRootComputation) {
  const char* const hlo_text = R"(
HloModule TestModule
ENTRY TestComputation {
%param1 = f16[2048,1024] parameter(0)
%param2 = f16[1024,33708] parameter(1)
ROOT %dot.2309 = f16[2048,33708]{1,0} dot(f16[2048,1024]{1,0} %param1,
f16[1024,33708]{0,1} %param2),
lhs_contracting_dims={1}, rhs_contracting_dims={0}
})";
  auto module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();

  CublasGemmPadForTensorCores pass;
  const bool changed = pass.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  SCOPED_TRACE(module->ToString());

  // Scalar f16 constant that both pad instructions use as their fill value.
  const auto pad_value = AllOf(op::Shape("f16[]"), op::Constant());

  // The LHS is wrapped in a pad even though its shape stays f16[2048, 1024].
  const auto padded_lhs =
      AllOf(op::Shape("f16[2048, 1024]"),
            op::Pad(AllOf(op::Shape("f16[2048, 1024]"), op::Parameter()),
                    pad_value));

  // The RHS grows from 33708 to 33712 columns.
  const auto padded_rhs =
      AllOf(op::Shape("f16[1024, 33712]"),
            op::Pad(AllOf(op::Shape("f16[1024, 33708]"), op::Parameter()),
                    pad_value));

  const auto padded_dot =
      AllOf(op::Shape("f16[2048, 33712]"),
            op::Dot(padded_lhs, padded_rhs,
                    /*lhs_contracting_dim=*/1,
                    /*rhs_contracting_dim=*/0));

  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              AllOf(op::Shape("f16[2048, 33708]"), op::Slice(padded_dot)));
}
// Two chained f16 dots where the second consumes the first. Both are expected
// to be rewritten: the first dot's result is sliced back to f16[2048, 33708]
// and then padded again to feed the second dot, whose f16[2048, 8] result is
// sliced down to the original f16[2048, 1].
TEST_F(CublasGemmPadForTensorCoresTest, TwoDotsComputation) {
  auto module = ParseAndReturnVerifiedModule(R"(
HloModule TestModule
ENTRY TestComputation {
%param1 = f16[2048,1024] parameter(0)
%param2 = f16[1024,33708] parameter(1)
%param3 = f16[33708, 1] parameter(2)
%dot1 = f16[2048,33708]{1,0} dot(f16[2048,1024]{1,0} %param1,
f16[1024,33708]{0,1} %param2),
lhs_contracting_dims={1}, rhs_contracting_dims={0}
ROOT %dot2 = f16[2048, 1]{1,0} dot(f16[2048,33708]{1,0} %dot1,
f16[33708, 1]{0,1} %param3),
lhs_contracting_dims={1}, rhs_contracting_dims={0}
})")
                    .ValueOrDie();
  EXPECT_TRUE(CublasGemmPadForTensorCores().Run(module.get()).ValueOrDie());
  SCOPED_TRACE(module->ToString());

  // Scalar f16 constant used as the fill value of every fully-checked pad.
  const auto pad_value = AllOf(op::Shape("f16[]"), op::Constant());

  // First (inner) dot as seen from the root; its pad operands are only checked
  // for opcode and shape here and are verified in detail further down.
  const auto first_dot_outline =
      AllOf(op::Shape("f16[2048, 33712]"),
            op::Dot(AllOf(op::Shape("f16[2048, 1024]"), op::Pad()),
                    AllOf(op::Shape("f16[1024, 33712]"), op::Pad()),
                    /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/0));

  // LHS of the second dot: pad(slice(first dot)).
  const auto second_dot_lhs =
      AllOf(op::Shape("f16[2048, 33712]"),
            op::Pad(AllOf(op::Shape("f16[2048, 33708]"),
                          op::Slice(first_dot_outline)),
                    pad_value));

  // RHS of the second dot: pad(param3).
  const auto second_dot_rhs =
      AllOf(op::Shape("f16[33712, 8]"),
            op::Pad(AllOf(op::Shape("f16[33708, 1]"), op::Parameter()),
                    pad_value));

  auto* root = module->entry_computation()->root_instruction();
  // Fatal assertion: the operand chain walked below is only valid once this
  // structure has been confirmed. With a non-fatal check a malformed graph
  // would fall through to operand(0) calls on instructions that may not have
  // that operand.
  ASSERT_THAT(
      root,
      AllOf(op::Shape("f16[2048, 1]"),
            op::Slice(AllOf(
                op::Shape("f16[2048, 8]"),
                op::Dot(second_dot_lhs, second_dot_rhs,
                        /*lhs_contracting_dim=*/1,
                        /*rhs_contracting_dim=*/0)))));

  // root (slice) -> second dot -> pad -> slice -> first dot.
  const auto* first_dot =
      root->operand(0)->operand(0)->operand(0)->operand(0);
  EXPECT_THAT(
      first_dot,
      op::Dot(
          AllOf(op::Shape("f16[2048, 1024]"),
                op::Pad(AllOf(op::Shape("f16[2048, 1024]"), op::Parameter()),
                        pad_value)),
          AllOf(op::Shape("f16[1024, 33712]"),
                op::Pad(AllOf(op::Shape("f16[1024, 33708]"), op::Parameter()),
                        pad_value)),
          /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/0));
}
// A module that contains no dot instruction at all: the pass is expected to
// report that it changed nothing.
TEST_F(CublasGemmPadForTensorCoresTest, NoDotComputation) {
  const char* const hlo_text = R"(
HloModule TestModule
ENTRY TestComputation {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %maximum = f32[] maximum(f32[] %x, f32[] %y)
})";
  auto module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();

  CublasGemmPadForTensorCores pass;
  const bool changed = pass.Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);
}
// Same dot shapes as the f16 tests but with f32 elements: the pass is expected
// to leave the module unchanged.
TEST_F(CublasGemmPadForTensorCoresTest, F32DotComputation) {
  const char* const hlo_text = R"(
HloModule TestModule
ENTRY TestComputation {
%param1 = f32[2048,1024] parameter(0)
%param2 = f32[1024,33708] parameter(1)
ROOT %dot.2309 = f32[2048,33708]{1,0} dot(f32[2048,1024]{1,0} %param1,
f32[1024,33708]{0,1} %param2),
lhs_contracting_dims={1}, rhs_contracting_dims={0}})";
  auto module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();

  CublasGemmPadForTensorCores pass;
  const bool changed = pass.Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);
}
// Same dot shapes as the f16 tests but with f64 elements: the pass is expected
// to leave the module unchanged.
TEST_F(CublasGemmPadForTensorCoresTest, F64DotComputation) {
  const char* const hlo_text = R"(
HloModule TestModule
ENTRY TestComputation {
%param1 = f64[2048,1024] parameter(0)
%param2 = f64[1024,33708] parameter(1)
ROOT %dot.2309 = f64[2048,33708]{1,0} dot(f64[2048,1024]{1,0} %param1,
f64[1024,33708]{0,1} %param2),
lhs_contracting_dims={1}, rhs_contracting_dims={0}})";
  auto module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();

  CublasGemmPadForTensorCores pass;
  const bool changed = pass.Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);
}
// An f16 dot whose dimensions (2048, 1024, 33712) are all already multiples
// of 8: the pass is expected to leave the module unchanged.
TEST_F(CublasGemmPadForTensorCoresTest, MultiplesOf8DotComputation) {
  const char* const hlo_text = R"(
HloModule TestModule
ENTRY TestComputation {
%param1 = f16[2048,1024] parameter(0)
%param2 = f16[1024,33712] parameter(1)
ROOT %dot.2309 = f16[2048,33712]{1,0} dot(f16[2048,1024]{1,0} %param1,
f16[1024,33712]{0,1} %param2),
lhs_contracting_dims={1}, rhs_contracting_dims={0}})";
  auto module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();

  CublasGemmPadForTensorCores pass;
  const bool changed = pass.Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);
}
// The original dot carries op_type/op_name metadata. After the pass rewrites
// the graph, the instruction that ends up as the entry root is expected to
// carry that same metadata.
TEST_F(CublasGemmPadForTensorCoresTest, CheckSavingMetadata) {
  auto module = ParseAndReturnVerifiedModule(R"(
HloModule TestModule
ENTRY TestComputation {
%param1 = f16[2048,1024] parameter(0)
%param2 = f16[1024,33708] parameter(1)
ROOT %dot.2309 = f16[2048,33708]{1,0} dot(f16[2048,1024]{1,0} %param1,
f16[1024,33708]{0,1} %param2),
lhs_contracting_dims={1}, rhs_contracting_dims={0},
metadata={op_type="MatMul" op_name="transformer_v2/Transformer/decode/embedding_shared_weights_1/presoftmax_linear/MatMul"}
})")
                    .ValueOrDie();
  EXPECT_TRUE(CublasGemmPadForTensorCores().Run(module.get()).ValueOrDie());
  // Trace the module only after the pass has run, as the other tests in this
  // file do, so a failure below prints the rewritten graph whose root is being
  // inspected rather than the untouched input.
  SCOPED_TRACE(module->ToString());

  // Bound by reference: the instruction is owned by `module`, which outlives
  // this use, so no copy of the metadata is needed.
  const auto& metadata =
      module->entry_computation()->root_instruction()->metadata();
  EXPECT_EQ("MatMul", metadata.op_type());
  EXPECT_EQ(
      "transformer_v2/Transformer/decode/embedding_shared_weights_1/"
      "presoftmax_linear/MatMul",
      metadata.op_name());
}
} // anonymous namespace
} // namespace gpu
} // namespace xla