Add ConfigProto.experimental.disable_output_partition_graphs().
Currently, DirectSession retains a copy of the Graph for each partition of a subgraph so that it can respond to a Run() call with `RunOptions.output_partition_graphs == true`, or build a cost model (which must be explicitly enabled via `ConfigProto.graph_options.build_cost_model > 0`). When this experimental option is set and cost models are not enabled, the session can release those Graph copies.
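
For reference, a minimal usage sketch against the C++ Session API (not part of this change; the GraphDef argument and the "my_target_node" name are placeholders):

  #include <memory>

  #include "tensorflow/core/lib/core/errors.h"
  #include "tensorflow/core/public/session.h"
  #include "tensorflow/core/public/session_options.h"

  tensorflow::Status RunWithoutPartitionGraphs(
      const tensorflow::GraphDef& graph_def) {
    tensorflow::SessionOptions options;
    options.config.mutable_experimental()->set_disable_output_partition_graphs(
        true);

    std::unique_ptr<tensorflow::Session> session(
        tensorflow::NewSession(options));
    TF_RETURN_IF_ERROR(session->Create(graph_def));

    // Ordinary Run() calls behave as before; the session simply no longer
    // retains a per-partition copy of the Graph.
    TF_RETURN_IF_ERROR(session->Run({}, {}, {"my_target_node"}, nullptr));

    // Requesting partition graphs is now rejected.
    tensorflow::RunOptions run_options;
    run_options.set_output_partition_graphs(true);
    tensorflow::RunMetadata run_metadata;
    tensorflow::Status s = session->Run(run_options, {}, {}, {"my_target_node"},
                                        nullptr, &run_metadata);
    // `s` is expected to be an InvalidArgument error here.

    return session->Close();
  }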
PiperOrigin-RevId: 272658318
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 0154b60..7effe58 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -754,12 +754,18 @@
// If requested via RunOptions, output the partition graphs.
if (run_options.output_partition_graphs()) {
- protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
- run_metadata->mutable_partition_graphs();
- for (const PerPartitionExecutorsAndLib& exec_and_lib :
- executors_and_keys->items) {
- GraphDef* partition_graph_def = partition_graph_defs->Add();
- exec_and_lib.graph->ToGraphDef(partition_graph_def);
+ if (options_.config.experimental().disable_output_partition_graphs()) {
+ return errors::InvalidArgument(
+ "RunOptions.output_partition_graphs() is not supported when "
+ "disable_output_partition_graphs is true.");
+ } else {
+ protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
+ run_metadata->mutable_partition_graphs();
+ for (const PerPartitionExecutorsAndLib& exec_and_lib :
+ executors_and_keys->items) {
+ GraphDef* partition_graph_def = partition_graph_defs->Add();
+ exec_and_lib.graph->ToGraphDef(partition_graph_def);
+ }
}
}
metrics::UpdateGraphExecTime(options_.env->NowMicros() - start_time_usecs);
@@ -1353,12 +1359,16 @@
TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
device->name(),
partition_graph.get()));
- item->graph = std::move(partition_graph);
+
item->executor = nullptr;
item->device = device;
auto executor_type = options_.config.experimental().executor_type();
TF_RETURN_IF_ERROR(
- NewExecutor(executor_type, params, *item->graph, &item->executor));
+ NewExecutor(executor_type, params, *partition_graph, &item->executor));
+ if (!options_.config.experimental().disable_output_partition_graphs() ||
+ options_.config.graph_options().build_cost_model() > 0) {
+ item->graph = std::move(partition_graph);
+ }
}
// Cache the mapping from input/output names to graph elements to
diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc
index 454f144..86eb109 100644
--- a/tensorflow/core/common_runtime/direct_session_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_test.cc
@@ -226,6 +226,44 @@
absl::StrContains(s.error_message(), "optimize_for_static_graph"));
}
+TEST_F(DirectSessionMinusAXTest,
+ RunSimpleNetwork_DisableOutputPartitionGraphs) {
+ Initialize({3, 2, -1, 0});
+ SessionOptions options(DefaultSessionOptions());
+ options.config.mutable_experimental()->set_disable_output_partition_graphs(
+ true);
+ auto session = absl::WrapUnique(NewSession(options));
+
+ ASSERT_TRUE(session != nullptr);
+ TF_ASSERT_OK(session->Create(def_));
+ std::vector<std::pair<string, Tensor>> inputs;
+
+ // Request two targets: one fetch output and one non-fetched output.
+ std::vector<string> output_names = {y_ + ":0"};
+ std::vector<string> target_nodes = {y_neg_};
+ std::vector<Tensor> outputs;
+ Status s = session->Run(inputs, output_names, target_nodes, &outputs);
+ TF_ASSERT_OK(s);
+
+ ASSERT_EQ(1, outputs.size());
+ // The first output should be initialized and have the correct
+ // output.
+ auto mat = outputs[0].matrix<float>();
+ ASSERT_TRUE(outputs[0].IsInitialized());
+ EXPECT_FLOAT_EQ(5.0, mat(0, 0));
+
+ // The Run() call should fail when `output_partition_graphs` is set to true.
+ RunOptions run_options;
+ run_options.set_output_partition_graphs(true);
+ RunMetadata run_metadata;
+ s = session->Run(run_options, inputs, output_names, target_nodes, &outputs,
+ &run_metadata);
+
+ EXPECT_TRUE(errors::IsInvalidArgument(s));
+ EXPECT_TRUE(
+ absl::StrContains(s.error_message(), "disable_output_partition_graphs"));
+}
+
TEST_F(DirectSessionMinusAXTest, TestTensorConnection) {
Initialize({3, 2, -1, 0});
auto session = CreateSession();
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index ca8b7a7..8d1532c 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -1917,6 +1917,13 @@
std::unique_ptr<ProfileHandler> ph;
FillPerStepState(rcg, req.options(), step_id, count, &pss, &ph);
+ if (pss.collect_partition_graphs &&
+ session_opts_.config.experimental().disable_output_partition_graphs()) {
+ return errors::InvalidArgument(
+ "RunOptions.output_partition_graphs() is not supported when "
+ "disable_output_partition_graphs is true.");
+ }
+
Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, resp,
&cancellation_manager_, false);
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
index f0f6c3f..07bb8e3 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
@@ -286,6 +286,41 @@
TF_CHECK_OK(session->Close());
}
+TEST(GrpcSessionTest, DisableOutputPartitionGraphs) {
+ GraphDef graph;
+ string node_names[3];
+ CreateGraphDef(&graph, node_names);
+
+ std::unique_ptr<test::TestCluster> cluster;
+ TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
+
+ SessionOptions options = Options(cluster->targets()[0], 1);
+ options.config.mutable_experimental()->set_disable_output_partition_graphs(
+ true);
+
+ std::unique_ptr<Session> session(NewRemote(options));
+ ASSERT_TRUE(session != nullptr);
+
+ TF_CHECK_OK(session->Create(graph));
+ {
+ // Just run to target node.
+ TF_CHECK_OK(session->Run({}, {}, {node_names[2]}, nullptr));
+ }
+ {
+ // Attempting to get the partition graphs should fail.
+ RunOptions run_options;
+ run_options.set_output_partition_graphs(true);
+ RunMetadata run_metadata;
+ Status s = session->Run(run_options, {}, {}, {node_names[2]}, nullptr,
+ &run_metadata);
+ EXPECT_TRUE(errors::IsInvalidArgument(s));
+ EXPECT_TRUE(absl::StrContains(s.error_message(),
+ "disable_output_partition_graphs"));
+ }
+
+ TF_CHECK_OK(session->Close());
+}
+
// A = [3 2; -1 0]; x = rand(2, 1); We want to compute the largest
// eigenvalue for A, which is 2.0. Iteratively, we do
// repeat x = y / y.norm(); y = A * x; end
diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto
index e8caa26..41b260f 100644
--- a/tensorflow/core/protobuf/config.proto
+++ b/tensorflow/core/protobuf/config.proto
@@ -553,6 +553,13 @@
// to an "execute" operation. The kernel for these operations is responsible
// to lower the encapsulated graph to a particular device.
bool enable_mlir_bridge = 13;
+
+ // If true, the session will not store an additional copy of the graph for
+ // each subgraph.
+ //
+ // If this option is set to true when a session is created, the
+ // `RunOptions.output_partition_graphs` option must not be set.
+ bool disable_output_partition_graphs = 14;
};
Experimental experimental = 16;
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt
index 0ddd53d..b34809b 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt
@@ -75,6 +75,12 @@
label: LABEL_OPTIONAL
type: TYPE_BOOL
}
+ field {
+ name: "disable_output_partition_graphs"
+ number: 14
+ label: LABEL_OPTIONAL
+ type: TYPE_BOOL
+ }
reserved_range {
start: 2
end: 3
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt
index 6371374..db4ba6a 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt
@@ -198,6 +198,12 @@
label: LABEL_OPTIONAL
type: TYPE_BOOL
}
+ field {
+ name: "disable_output_partition_graphs"
+ number: 14
+ label: LABEL_OPTIONAL
+ type: TYPE_BOOL
+ }
reserved_range {
start: 2
end: 3