[tf.data] Only apply the optimization `enable_gradient_descent` on the main dataset. Also update the unit test.
PiperOrigin-RevId: 335486882
Change-Id: I864f44d90fc3d84f702b3e53ed6b226362dd27a7
diff --git a/tensorflow/core/grappler/optimizers/data/enable_gradient_descent.cc b/tensorflow/core/grappler/optimizers/data/enable_gradient_descent.cc
index fa31be8..4ece165 100644
--- a/tensorflow/core/grappler/optimizers/data/enable_gradient_descent.cc
+++ b/tensorflow/core/grappler/optimizers/data/enable_gradient_descent.cc
@@ -31,6 +31,8 @@
constexpr char kAlgorithm[] = "algorithm";
constexpr char kModelDataset[] = "ModelDataset";
+constexpr char kRetValOp[] = "_Retval";
+
constexpr int64 HILL_CLIMB = 0;
constexpr int64 GRADIENT_DESCENT = 1;
@@ -47,9 +49,21 @@
}
MutableGraphView graph(output);
- int index = graph_utils::FindGraphNodeWithOp(kModelDataset, *output);
+ for (const auto& fetch_name : item.fetch) {
+ // If the GrapplerItem is derived from a FunctionDef, we don't optimize it,
+ // because we only want to enable gradient descent on the main dataset
+ // pipeline.
+ auto fetch = graph.GetNode(fetch_name);
+ if (fetch == nullptr || fetch->op() == kRetValOp) {
+ // Heuristic: If the fetch nodes are Retval ops, this item is from a
+ // function.
+ return Status::OK();
+ }
+ }
+ int index = graph_utils::FindGraphNodeWithOp(kModelDataset, *output);
NodeDef& model_node = *(output->mutable_node(index));
+
if (model_node.attr().at(kAlgorithm).i() == HILL_CLIMB) {
(*model_node.mutable_attr())[kAlgorithm].set_i(GRADIENT_DESCENT);
stats->num_changes++;
diff --git a/tensorflow/core/grappler/optimizers/data/enable_gradient_descent_test.cc b/tensorflow/core/grappler/optimizers/data/enable_gradient_descent_test.cc
index 86d66f3..c623c53 100644
--- a/tensorflow/core/grappler/optimizers/data/enable_gradient_descent_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/enable_gradient_descent_test.cc
@@ -41,12 +41,13 @@
return optimizer.Optimize(nullptr, item, output);
}
-class SimpleRewrite : public ::testing::TestWithParam<std::tuple<bool, int64>> {
-};
+class SimpleRewrite
+ : public ::testing::TestWithParam<std::tuple<bool, int64, string>> {};
TEST_P(SimpleRewrite, EnableGradientDescentTest) {
const bool autotune = std::get<0>(GetParam());
const int64 algorithm_index = std::get<1>(GetParam());
+ const string op = std::get<2>(GetParam());
using test::function::NDef;
GrapplerItem item;
@@ -58,7 +59,9 @@
NDef("batch_size", "Const", {}, {{"value", 5}, {"dtype", DT_INT32}}),
NDef("batch", "BatchDataset", {"range", "batch_size"}, {}),
NDef("model", "ModelDataset", {"batch"},
- {{"algorithm", algorithm_index}})});
+ {{"algorithm", algorithm_index}}),
+ NDef("Sink", op, {"model"}, {})});
+ item.fetch.push_back("Sink");
GraphDef output;
TF_ASSERT_OK(OptimizeWithEnableGradientDescent(item, &output, autotune));
@@ -67,12 +70,13 @@
NodeDef model_node =
output.node(graph_utils::FindGraphNodeWithName("model", output));
EXPECT_EQ(model_node.attr().at("algorithm").i(),
- autotune ? 1 : algorithm_index);
+ (autotune && op != "_Retval") ? 1 : algorithm_index);
}
-INSTANTIATE_TEST_SUITE_P(Test, SimpleRewrite,
- ::testing::Combine(::testing::Values(false, true),
- ::testing::Values(0, 1)));
+INSTANTIATE_TEST_SUITE_P(
+ Test, SimpleRewrite,
+ ::testing::Combine(::testing::Values(false, true), ::testing::Values(0, 1),
+ ::testing::Values("Identity", "_Retval")));
} // namespace
} // namespace grappler
} // namespace tensorflow