Fix some ClangTidy warnings in Grappler's dependency optimizer on a boring Friday afternoon.
PiperOrigin-RevId: 452882239
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 1e628bc..b879ab0 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -416,6 +416,8 @@
"//tensorflow/core/grappler/costs:graph_properties",
"//tensorflow/core/grappler/utils:topological_sort",
"@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc
index 88f1bd5..f2b2608 100644
--- a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc
@@ -18,6 +18,7 @@
#include <unordered_set>
#include "absl/container/flat_hash_map.h"
+#include "absl/strings/match.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
@@ -134,7 +135,7 @@
<< " to NoOp. Node has side effect.";
return false;
}
- if (node.op().rfind("Submodel", 0) == 0) {
+ if (absl::EndsWith(node.op(), "Submodel")) {
return false;
}
const OpDef* op_def = nullptr;
@@ -142,11 +143,12 @@
if (!status.ok() || op_def->output_arg_size() == 0) {
return false;
}
- const std::unordered_set<string> do_not_rewrite_ops{
- "Assert", "CheckNumerics", "_Retval",
- "_Arg", "_ParallelConcatUpdate", "TPUExecute",
- "TPUCompile", "ControlTrigger"};
- if (do_not_rewrite_ops.find(node.op()) != do_not_rewrite_ops.end()) {
+ static const absl::flat_hash_set<string>* do_not_rewrite_ops =
+ new absl::flat_hash_set<string>{
+ "Assert", "CheckNumerics", "_Retval",
+ "_Arg", "_ParallelConcatUpdate", "TPUExecute",
+ "TPUCompile", "ControlTrigger"};
+ if (do_not_rewrite_ops->find(node.op()) != do_not_rewrite_ops->end()) {
return false;
}
if (!SafeToRemoveIdentity(node)) {
@@ -750,7 +752,8 @@
GraphDef* optimized_graph) {
optimized_graph_ = optimized_graph;
*optimized_graph_ = item.graph;
- nodes_to_preserve_ = item.NodesToPreserve();
+ nodes_to_preserve_ = absl::flat_hash_set<string>(
+ item.NodesToPreserve().begin(), item.NodesToPreserve().end());
fetch_nodes_known_ = !item.fetch.empty();
CleanControlInputs();
diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.h b/tensorflow/core/grappler/optimizers/dependency_optimizer.h
index 4251a4a..4f37204 100644
--- a/tensorflow/core/grappler/optimizers/dependency_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.h
@@ -16,7 +16,8 @@
#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DEPENDENCY_OPTIMIZER_H_
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DEPENDENCY_OPTIMIZER_H_
-#include <unordered_set>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
@@ -73,9 +74,9 @@
void GroupCrossDeviceControlEdges(bool host_granularity);
bool fetch_nodes_known_;
- std::unordered_set<string> nodes_to_preserve_;
+ absl::flat_hash_set<string> nodes_to_preserve_;
std::unique_ptr<NodeMap> node_map_;
- std::unordered_map<const NodeDef*, int> node_to_idx_;
+ absl::flat_hash_map<const NodeDef*, int> node_to_idx_;
GraphDef* optimized_graph_; // Not owned.
};
diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
index c958e4e..266e8c4 100644
--- a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
+++ b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
@@ -295,9 +295,11 @@
std::vector<string> StageNames() {
std::vector<string> names;
- for (const auto& stage : stages_) {
- names.push_back(stage->stage_name());
- }
+ std::transform(
+ stages_.begin(), stages_.end(), std::back_inserter(names),
+ [](const std::unique_ptr<GraphOptimizerStage<Result>>& stage) {
+ return stage->stage_name();
+ });
return names;
}