Add ConfigProto.experimental.disable_output_partition_graphs().

Currently, DirectSession retains a copy of the Graph for each partition of a subgraph, in order to be able to respond to a Run() call with `RunOptions.output_partition_graphs == true` or build a cost model (which is explicitly enabled via `ConfigProto.graph_options.build_cost_model > 0`). This experimental option makes it possible (in conjunction with not enabling cost models) to release those Graph copies.

PiperOrigin-RevId: 272658318
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 0154b60..7effe58 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -754,12 +754,18 @@
 
   // If requested via RunOptions, output the partition graphs.
   if (run_options.output_partition_graphs()) {
-    protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
-        run_metadata->mutable_partition_graphs();
-    for (const PerPartitionExecutorsAndLib& exec_and_lib :
-         executors_and_keys->items) {
-      GraphDef* partition_graph_def = partition_graph_defs->Add();
-      exec_and_lib.graph->ToGraphDef(partition_graph_def);
+    if (options_.config.experimental().disable_output_partition_graphs()) {
+      return errors::InvalidArgument(
+          "RunOptions.output_partition_graphs() is not supported when "
+          "disable_output_partition_graphs is true.");
+    } else {
+      protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
+          run_metadata->mutable_partition_graphs();
+      for (const PerPartitionExecutorsAndLib& exec_and_lib :
+           executors_and_keys->items) {
+        GraphDef* partition_graph_def = partition_graph_defs->Add();
+        exec_and_lib.graph->ToGraphDef(partition_graph_def);
+      }
     }
   }
   metrics::UpdateGraphExecTime(options_.env->NowMicros() - start_time_usecs);
@@ -1353,12 +1359,16 @@
     TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
                                          device->name(),
                                          partition_graph.get()));
-    item->graph = std::move(partition_graph);
+
     item->executor = nullptr;
     item->device = device;
     auto executor_type = options_.config.experimental().executor_type();
     TF_RETURN_IF_ERROR(
-        NewExecutor(executor_type, params, *item->graph, &item->executor));
+        NewExecutor(executor_type, params, *partition_graph, &item->executor));
+    if (!options_.config.experimental().disable_output_partition_graphs() ||
+        options_.config.graph_options().build_cost_model() > 0) {
+      item->graph = std::move(partition_graph);
+    }
   }
 
   // Cache the mapping from input/output names to graph elements to
diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc
index 454f144..86eb109 100644
--- a/tensorflow/core/common_runtime/direct_session_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_test.cc
@@ -226,6 +226,44 @@
       absl::StrContains(s.error_message(), "optimize_for_static_graph"));
 }
 
+TEST_F(DirectSessionMinusAXTest,
+       RunSimpleNetwork_DisableOutputPartitionGraphs) {
+  Initialize({3, 2, -1, 0});
+  SessionOptions options(DefaultSessionOptions());
+  options.config.mutable_experimental()->set_disable_output_partition_graphs(
+      true);
+  auto session = absl::WrapUnique(NewSession(options));
+
+  ASSERT_TRUE(session != nullptr);
+  TF_ASSERT_OK(session->Create(def_));
+  std::vector<std::pair<string, Tensor>> inputs;
+
+  // Request two targets: one fetch output and one non-fetched output.
+  std::vector<string> output_names = {y_ + ":0"};
+  std::vector<string> target_nodes = {y_neg_};
+  std::vector<Tensor> outputs;
+  Status s = session->Run(inputs, output_names, target_nodes, &outputs);
+  TF_ASSERT_OK(s);
+
+  ASSERT_EQ(1, outputs.size());
+  // The first output should be initialized and have the correct
+  // output.
+  auto mat = outputs[0].matrix<float>();
+  ASSERT_TRUE(outputs[0].IsInitialized());
+  EXPECT_FLOAT_EQ(5.0, mat(0, 0));
+
+  // The Run() call should fail when `output_partition_graphs` is set to true.
+  RunOptions run_options;
+  run_options.set_output_partition_graphs(true);
+  RunMetadata run_metadata;
+  s = session->Run(run_options, inputs, output_names, target_nodes, &outputs,
+                   &run_metadata);
+
+  EXPECT_TRUE(errors::IsInvalidArgument(s));
+  EXPECT_TRUE(
+      absl::StrContains(s.error_message(), "disable_output_partition_graphs"));
+}
+
 TEST_F(DirectSessionMinusAXTest, TestTensorConnection) {
   Initialize({3, 2, -1, 0});
   auto session = CreateSession();
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index ca8b7a7..8d1532c 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -1917,6 +1917,13 @@
   std::unique_ptr<ProfileHandler> ph;
   FillPerStepState(rcg, req.options(), step_id, count, &pss, &ph);
 
+  if (pss.collect_partition_graphs &&
+      session_opts_.config.experimental().disable_output_partition_graphs()) {
+    return errors::InvalidArgument(
+        "RunOptions.output_partition_graphs() is not supported when "
+        "disable_output_partition_graphs is true.");
+  }
+
   Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, resp,
                                 &cancellation_manager_, false);
 
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
index f0f6c3f..07bb8e3 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
@@ -286,6 +286,41 @@
   TF_CHECK_OK(session->Close());
 }
 
+TEST(GrpcSessionTest, DisableOutputPartitionGraphs) {
+  GraphDef graph;
+  string node_names[3];
+  CreateGraphDef(&graph, node_names);
+
+  std::unique_ptr<test::TestCluster> cluster;
+  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
+
+  SessionOptions options = Options(cluster->targets()[0], 1);
+  options.config.mutable_experimental()->set_disable_output_partition_graphs(
+      true);
+
+  std::unique_ptr<Session> session(NewRemote(options));
+  ASSERT_TRUE(session != nullptr);
+
+  TF_CHECK_OK(session->Create(graph));
+  {
+    // Just run to target node.
+    TF_CHECK_OK(session->Run({}, {}, {node_names[2]}, nullptr));
+  }
+  {
+    // Attempting to get the partition graphs should fail.
+    RunOptions run_options;
+    run_options.set_output_partition_graphs(true);
+    RunMetadata run_metadata;
+    Status s = session->Run(run_options, {}, {}, {node_names[2]}, nullptr,
+                            &run_metadata);
+    EXPECT_TRUE(errors::IsInvalidArgument(s));
+    EXPECT_TRUE(absl::StrContains(s.error_message(),
+                                  "disable_output_partition_graphs"));
+  }
+
+  TF_CHECK_OK(session->Close());
+}
+
 // A = [3 2; -1 0]; x = rand(2, 1); We want to compute the largest
 // eigenvalue for A, which is 2.0. Iteratively, we do
 //   repeat x = y / y.norm(); y = A * x; end
diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto
index e8caa26..41b260f 100644
--- a/tensorflow/core/protobuf/config.proto
+++ b/tensorflow/core/protobuf/config.proto
@@ -553,6 +553,13 @@
     // to an "execute" operation. The kernel for these operations is responsible
     // to lower the encapsulated graph to a particular device.
     bool enable_mlir_bridge = 13;
+
+    // If true, the session will not store an additional copy of the graph for
+    // each subgraph.
+    //
+    // If this option is set to true when a session is created, the
+    // `RunOptions.output_partition_graphs` options must not be set.
+    bool disable_output_partition_graphs = 14;
   };
 
   Experimental experimental = 16;
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt
index 0ddd53d..b34809b 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt
@@ -75,6 +75,12 @@
       label: LABEL_OPTIONAL
       type: TYPE_BOOL
     }
+    field {
+      name: "disable_output_partition_graphs"
+      number: 14
+      label: LABEL_OPTIONAL
+      type: TYPE_BOOL
+    }
     reserved_range {
       start: 2
       end: 3
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt
index 6371374..db4ba6a 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt
@@ -198,6 +198,12 @@
         label: LABEL_OPTIONAL
         type: TYPE_BOOL
       }
+      field {
+        name: "disable_output_partition_graphs"
+        number: 14
+        label: LABEL_OPTIONAL
+        type: TYPE_BOOL
+      }
       reserved_range {
         start: 2
         end: 3