blob: e697204967db0c3965f182bba9ab9fcaaf47d029 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h"
#include <memory>
#include "mlir/IR/Builders.h" // from @llvm-project
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
using ::testing::_;
using ::testing::NiceMock;
using ::testing::Return;
using ::testing::Test;
// gMock test double for MlirOptimizationPass. name(), GetPassState() and Run()
// are all mocked, so each test decides their results (e.g. via ON_CALL).
class MockMlirOptimizationPass : public MlirOptimizationPass {
 public:
  // The generic MOCK_METHOD macro fails to build on Windows, so the legacy
  // arity-suffixed MOCK_[CONST_]METHODn macros are used here instead.
  MOCK_CONST_METHOD0(name, llvm::StringRef());
  MOCK_CONST_METHOD4(GetPassState,
                     MlirOptimizationPassState(const DeviceSet*,
                                               const ConfigProto&, const Graph&,
                                               const FunctionLibraryDefinition&));
  MOCK_METHOD4(Run, Status(const ConfigProto&, mlir::ModuleOp, const Graph&,
                           const FunctionLibraryDefinition&));
};
// gMock test double for MlirV1CompatOptimizationPass. name(), GetPassState()
// and the options-based Run() are all mocked.
class MockMlirV1CompatOptimizationPass : public MlirV1CompatOptimizationPass {
 public:
  // The generic MOCK_METHOD macro fails to build on Windows, so the legacy
  // arity-suffixed MOCK_[CONST_]METHODn macros are used here instead.
  MOCK_CONST_METHOD0(name, llvm::StringRef());
  MOCK_CONST_METHOD4(GetPassState,
                     MlirOptimizationPassState(const DeviceSet*,
                                               const ConfigProto&, const Graph&,
                                               const FunctionLibraryDefinition&));
  MOCK_METHOD2(Run,
               Status(const GraphOptimizationPassOptions&, mlir::ModuleOp));
};
// Optimization pass with a real Run(): it overwrites the module's
// "tf.versions" attribute and then returns the status given at construction.
// Tests use it to observe whether changes made on the MLIR side show up in the
// original TF graph.
class ModifyMlirModulePass : public MlirOptimizationPass {
 public:
  explicit ModifyMlirModulePass(Status run_status) : run_status_(run_status) {}

  // The generic MOCK_METHOD macro fails to build on Windows, so the legacy
  // arity-suffixed MOCK_CONST_METHODn macros are used here instead.
  MOCK_CONST_METHOD0(name, llvm::StringRef());
  MOCK_CONST_METHOD4(GetPassState,
                     MlirOptimizationPassState(
                         const DeviceSet* device_set,
                         const ConfigProto& config_proto, const Graph& graph,
                         const FunctionLibraryDefinition& function_library));

  // Ignores every input except `module`, on which it sets a fixed
  // "tf.versions" dictionary, then reports `run_status_`.
  Status Run(const ConfigProto& config_proto, mlir::ModuleOp module,
             const Graph& graph,
             const FunctionLibraryDefinition& function_library) override {
    mlir::Builder builder(module.getContext());
    // Brace-initializer clauses are evaluated left to right, matching the
    // original statement order.
    mlir::NamedAttribute version_fields[] = {
        builder.getNamedAttr("producer", builder.getI32IntegerAttr(0)),
        builder.getNamedAttr("min_consumer", builder.getI32IntegerAttr(0)),
        builder.getNamedAttr("bad_consumers",
                             builder.getI32ArrayAttr({1, 2, 3, 4})),
    };
    module->setAttr("tf.versions", builder.getDictionaryAttr(version_fields));
    return run_status_;
  }

  // Status returned by every call to Run().
  Status run_status_;
};
class MlirGraphOptimizationPassTest : public Test {
public:
void Init(Status pass_run_result,
const std::vector<MlirOptimizationPassState>& pass_states) {
graph_ = std::make_unique<Graph>(OpRegistry::Global());
int pass_priority = 0;
for (const MlirOptimizationPassState& pass_state : pass_states) {
auto optimization_pass =
std::make_unique<NiceMock<MockMlirOptimizationPass>>();
ON_CALL(*optimization_pass, GetPassState(_, _, _, _))
.WillByDefault(Return(pass_state));
ON_CALL(*optimization_pass, Run(_, _, _, _))
.WillByDefault(Return(pass_run_result));
MlirOptimizationPassRegistry::Global().Add(pass_priority++,
std::move(optimization_pass));
}
flib_.reset(new FunctionLibraryDefinition(graph_->flib_def()));
}
void AddModuleModificationPass(MlirOptimizationPassState pass_state,
Status run_status) {
// Add FallbackEnabled pass that modifies the graph.
auto optimization_pass =
std::make_unique<NiceMock<ModifyMlirModulePass>>(run_status);
ON_CALL(*optimization_pass, GetPassState(_, _, _, _))
.WillByDefault(Return(pass_state));
MlirOptimizationPassRegistry::Global().Add(10,
std::move(optimization_pass));
}
void TearDown() override {
MlirOptimizationPassRegistry::Global().ClearPasses();
}
void verifyGraph(const GraphDef& original_graph_def, bool changed = false) {
// Proto matchers might be unavailable in the OSS.
#if defined(PLATFORM_GOOGLE)
GraphDef resulted_graph_def;
graph_->ToGraphDef(&resulted_graph_def);
if (changed)
EXPECT_THAT(resulted_graph_def,
Not(::testing::proto::IgnoringRepeatedFieldOrdering(
::testing::EquivToProto(original_graph_def))));
else
EXPECT_THAT(resulted_graph_def,
::testing::proto::IgnoringRepeatedFieldOrdering(
::testing::EquivToProto(original_graph_def)));
#endif
}
ConfigProto config_proto_;
MlirFunctionOptimizationPass function_optimization_pass_;
DeviceSet device_set_;
std::unique_ptr<Graph> graph_;
std::unique_ptr<FunctionLibraryDefinition> flib_;
std::vector<std::string> control_ret_node_names_;
bool control_rets_updated_{false};
};
// A single Enabled pass that fails: its error must be returned unchanged to the
// caller, and the graph must be left as it was.
TEST_F(MlirGraphOptimizationPassTest, OptimizationPassFailsNoFallback) {
  const Status aborted(error::Code::ABORTED, "aborted");
  Init(aborted, {MlirOptimizationPassState::Enabled});

  GraphDef graph_before;
  graph_->ToGraphDef(&graph_before);

  const Status result = function_optimization_pass_.Run(
      device_set_, config_proto_, &graph_, flib_.get(),
      &control_ret_node_names_, &control_rets_updated_);
  EXPECT_EQ(result, aborted);
  verifyGraph(graph_before);
}
// Disabled and FallbackEnabled passes that all fail, including one that has
// already modified the MLIR module: the overall Run must still succeed, and the
// graph must stay identical to the original.
TEST_F(MlirGraphOptimizationPassTest, OptimizationPassFailsDisabledFallback) {
  const Status aborted(error::Code::ABORTED, "aborted");
  Init(aborted, {MlirOptimizationPassState::Disabled,
                 MlirOptimizationPassState::FallbackEnabled});

  GraphDef graph_before;
  graph_->ToGraphDef(&graph_before);
  AddModuleModificationPass(MlirOptimizationPassState::FallbackEnabled,
                            aborted);

  const Status result = function_optimization_pass_.Run(
      device_set_, config_proto_, &graph_, flib_.get(),
      &control_ret_node_names_, &control_rets_updated_);
  EXPECT_EQ(result, Status::OK());
  verifyGraph(graph_before);
}
// FallbackEnabled passes that all succeed, one of which modifies the MLIR
// module: Run must succeed, and the graph is expected to differ from the
// original.
TEST_F(MlirGraphOptimizationPassTest, OptimizationPassDoesNotFailFallback) {
  const Status ok = Status::OK();
  Init(ok, {MlirOptimizationPassState::FallbackEnabled});

  GraphDef graph_before;
  graph_->ToGraphDef(&graph_before);
  AddModuleModificationPass(MlirOptimizationPassState::FallbackEnabled, ok);

  const Status result = function_optimization_pass_.Run(
      device_set_, config_proto_, &graph_, flib_.get(),
      &control_ret_node_names_, &control_rets_updated_);
  EXPECT_EQ(result, Status::OK());
  verifyGraph(graph_before, /*changed=*/true);
}
// Registering two passes with the same priority must abort with the
// registry's uniqueness message.
TEST(MlirOptimizationPassRegistry, RegisterPassesWithTheSamePriorityFails) {
  MlirOptimizationPassRegistry::Global().Add(
      0, std::make_unique<NiceMock<MockMlirOptimizationPass>>());
  EXPECT_DEATH(MlirOptimizationPassRegistry::Global().Add(
                   0, std::make_unique<NiceMock<MockMlirOptimizationPass>>()),
               "Pass priority must be unique.");
  // The registry is process-global, and this plain TEST is not covered by the
  // fixture's TearDown(). Clear the priority-0 pass registered above.
  // Otherwise it would leak into later tests (e.g. under shuffle or repeat),
  // where Init() registering priority 0 would hit the same uniqueness failure.
  MlirOptimizationPassRegistry::Global().ClearPasses();
}
// The V1-compat registry accepts only one pass; a second registration must
// abort with the message below.
// NOTE(review): the first pass registered here is never removed from the
// process-global registry, so it presumably persists for later tests (e.g.
// under --gtest_repeat). Consider clearing it if the registry exposes a way to
// do so.
TEST(MlirV1CompatOptimizationPassRegistry, RegisterMultiplePassesFails) {
  auto& registry = MlirV1CompatOptimizationPassRegistry::Global();
  registry.Add(std::make_unique<NiceMock<MockMlirV1CompatOptimizationPass>>());
  EXPECT_DEATH(
      registry.Add(
          std::make_unique<NiceMock<MockMlirV1CompatOptimizationPass>>()),
      "Only a single pass can be registered");
}
} // namespace tensorflow