// blob: 034adbf44fe405889190efed337499a00652f8f9 [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include <cstdlib>
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/list_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/core/common_runtime/direct_session.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
// Test double for XlaActivityListener. It keeps a copy of the most recent
// auto-clustering activity and the most recent JIT-compilation activity it was
// sent, so tests can inspect them. Optimization remarks are accepted and
// dropped.
class TestListener : public XlaActivityListener {
 public:
  ~TestListener() override = default;

  // Records `auto_clustering_activity`, replacing any previously seen value.
  Status Listen(
      const XlaAutoClusteringActivity& auto_clustering_activity) override {
    auto_clustering_activity_ = auto_clustering_activity;
    return Status::OK();
  }

  // Records `jit_compilation_activity`, replacing any previously seen value.
  Status Listen(
      const XlaJitCompilationActivity& jit_compilation_activity) override {
    jit_compilation_activity_ = jit_compilation_activity;
    return Status::OK();
  }

  // Optimization remarks are not inspected by these tests; always succeeds.
  Status Listen(const XlaOptimizationRemark& /*optimization_remark*/) override {
    return Status::OK();
  }

  // Latest auto-clustering activity (default-constructed until one arrives).
  const XlaAutoClusteringActivity& auto_clustering_activity() const {
    return auto_clustering_activity_;
  }

  // Latest JIT-compilation activity (default-constructed until one arrives).
  const XlaJitCompilationActivity& jit_compilation_activity() const {
    return jit_compilation_activity_;
  }

 private:
  XlaAutoClusteringActivity auto_clustering_activity_;
  XlaJitCompilationActivity jit_compilation_activity_;
};
// Fixture that creates a TestListener and hands it to
// RegisterXlaActivityListener when each test is constructed. Tests reach the
// listener through listener().
class XlaActivityListenerTest : public ::testing::Test {
 protected:
  XlaActivityListenerTest() {
    std::unique_ptr<TestListener> owned_listener =
        absl::make_unique<TestListener>();
    // Keep a non-owning handle before ownership is moved away below.
    listener_ = owned_listener.get();
    // Ownership transfers to RegisterXlaActivityListener (presumably a
    // process-wide registry that keeps the listener alive — TODO confirm).
    RegisterXlaActivityListener(std::move(owned_listener));
  }

  // Non-owning pointer to the listener registered by this fixture.
  TestListener* listener() const { return listener_; }

 private:
  TestListener* listener_ = nullptr;  // Not owned.
};
// Builds a GraphDef whose ops are all assigned to the local CPU device.
// The graph starts from a DT_FLOAT placeholder named "A", followed by five
// rounds of
//   value = MatMul(value, value)   (op named "matmul_<round>")
//   value = Add(value, value)      (op named "add_<round>")
// so the last op in the chain is "add_4".
GraphDef CreateGraphDef() {
  Scope root = Scope::NewRootScope().ExitOnError().WithAssignedDevice(
      "/job:localhost/replica:0/task:0/device:CPU:0");

  Output value = ops::Placeholder(root.WithOpName("A"), DT_FLOAT);
  for (int round = 0; round < 5; ++round) {
    const Output product = ops::MatMul(
        root.WithOpName(absl::StrCat("matmul_", round)), value, value);
    value = ops::Add(root.WithOpName(absl::StrCat("add_", round)), product,
                     product);
  }

  GraphDef graph_def;
  root.graph()->ToGraphDef(&graph_def);
  return graph_def;
}
TEST_F(XlaActivityListenerTest, Test) {
GraphDef graph_def = CreateGraphDef();
SessionOptions options;
options.config.mutable_graph_options()
->mutable_optimizer_options()
->set_global_jit_level(OptimizerOptions::ON_2);
std::unique_ptr<Session> session(NewSession(options));
TF_ASSERT_OK(session->Create(graph_def));
std::vector<std::string> output_names = {std::string("add_4:0")};
Tensor tensor_2x2(DT_FLOAT, TensorShape({2, 2}));
for (int i = 0; i < 4; i++) {
tensor_2x2.matrix<float>()(i / 2, i % 2) = 5 * i;
}
Tensor tensor_3x3(DT_FLOAT, TensorShape({3, 3}));
for (int i = 0; i < 9; i++) {
tensor_3x3.matrix<float>()(i / 3, i % 3) = 5 * i;
}
std::vector<std::pair<string, Tensor>> inputs_2x2 = {{"A", tensor_2x2}};
std::vector<Tensor> outputs;
TF_ASSERT_OK(session->Run(inputs_2x2, output_names, /*target_node_names=*/{},
&outputs));
absl::string_view expected_auto_clustering_activity =
R"(global_jit_level: ON_2
cpu_global_jit_enabled: true
summary {
unclustered_node_count: 4
clustered_node_count: 14
clusters {
name: "cluster_0"
size: 14
op_histogram {
op: "Add"
count: 1
}
op_histogram {
op: "Const"
count: 4
}
op_histogram {
op: "MatMul"
count: 5
}
op_histogram {
op: "Mul"
count: 4
}
}
unclustered_op_histogram {
op: "NoOp"
count: 2
}
unclustered_op_histogram {
op: "_Arg"
count: 1
}
unclustered_op_histogram {
op: "_Retval"
count: 1
}
}
)";
EXPECT_EQ(listener()->auto_clustering_activity().DebugString(),
expected_auto_clustering_activity);
EXPECT_EQ(listener()->jit_compilation_activity().cluster_name(), "cluster_0");
EXPECT_EQ(listener()->jit_compilation_activity().compile_count(), 1);
int64 first_compile_time =
listener()->jit_compilation_activity().compile_time_us();
EXPECT_GT(first_compile_time, 0);
EXPECT_EQ(listener()->jit_compilation_activity().cumulative_compile_time_us(),
first_compile_time);
std::vector<std::pair<string, Tensor>> inputs_3x3 = {{"A", tensor_3x3}};
outputs.clear();
for (int i = 0; i < 3; i++) {
TF_ASSERT_OK(session->Run(inputs_3x3, output_names,
/*target_node_names=*/{}, &outputs));
}
EXPECT_EQ(listener()->jit_compilation_activity().cluster_name(), "cluster_0");
EXPECT_EQ(listener()->jit_compilation_activity().compile_count(), 2);
EXPECT_GT(listener()->jit_compilation_activity().compile_time_us(), 0);
EXPECT_EQ(listener()->jit_compilation_activity().cumulative_compile_time_us(),
first_compile_time +
listener()->jit_compilation_activity().compile_time_us());
}
} // namespace
} // namespace tensorflow
// Test entry point. It turns on the CPU global-JIT flag of the
// mark-for-compilation pass before any test runs, so the sessions created by
// the tests can auto-cluster ops on CPU.
int main(int argc, char** argv) {
  auto* mark_for_compilation_flags =
      tensorflow::GetMarkForCompilationPassFlags();
  mark_for_compilation_flags->tf_xla_cpu_global_jit = true;

  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}