Optimize attr lookup by using `AttrSlice::FindByString()` when possible.
`AttrSlice::Find()` performs a linear scan of the attrs and can't be changed until `protobuf::Map` supports efficient lookups using a `StringPiece`. But in many cases, `AttrSlice::Find()` is called with a `std::string`. Switching to `AttrSlice::FindByString()`, which uses an O(1) map lookup, improves performance when there are many attrs (e.g. some `CsvDataset`s, which have an attr for every column and don't scale well to CSVs with O(100k) columns).
PiperOrigin-RevId: 441823425
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index 3c734d8..027b66c 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -72,7 +72,7 @@
bool* is_type_list, DataTypeVector* dtypes) {
dtypes->clear();
if (!arg_def.type_list_attr().empty()) {
- const AttrValue* v = attrs.Find(arg_def.type_list_attr());
+ const AttrValue* v = attrs.FindByString(arg_def.type_list_attr());
if (v == nullptr) {
return errors::NotFound("type list attr not found: ",
arg_def.type_list_attr());
@@ -87,7 +87,7 @@
*is_type_list = false;
int num = 1;
if (!arg_def.number_attr().empty()) {
- const AttrValue* v = attrs.Find(arg_def.number_attr());
+ const AttrValue* v = attrs.FindByString(arg_def.number_attr());
if (v == nullptr) {
return errors::NotFound("number attr not found: ", arg_def.number_attr());
}
@@ -100,7 +100,7 @@
} else if (arg_def.type_attr().empty()) {
dtype = DT_INVALID;
} else {
- const AttrValue* v = attrs.Find(arg_def.type_attr());
+ const AttrValue* v = attrs.FindByString(arg_def.type_attr());
if (v == nullptr) {
return errors::NotFound("type attr not found: ", arg_def.type_attr());
}
@@ -121,7 +121,7 @@
// attr_values should specify all attrs defined in fdef, except for those
// which have a default value
for (const auto& attr : sig.attr()) {
- const AttrValue* attr_value = attr_values.Find(attr.name());
+ const AttrValue* attr_value = attr_values.FindByString(attr.name());
if (attr_value) {
Status status = AttrValueHasType(*attr_value, attr.type());
if (!status.ok()) {
@@ -776,9 +776,9 @@
}
}
- auto substitute = [attr_values, &sig](StringPiece name, AttrValue* val) {
+ auto substitute = [attr_values, &sig](const string& name, AttrValue* val) {
// Look for a specified value...
- if (const AttrValue* v = attr_values.Find(name)) {
+ if (const AttrValue* v = attr_values.FindByString(name)) {
*val = *v;
return true;
}
diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc
index 378caeb..940852b 100644
--- a/tensorflow/core/framework/node_def_util.cc
+++ b/tensorflow/core/framework/node_def_util.cc
@@ -181,10 +181,9 @@
}
}
-Status AttrSlice::Find(StringPiece attr_name,
- const AttrValue** attr_value) const {
- *attr_value = Find(attr_name);
- if (*attr_value != nullptr) {
+Status AttrSlice::CheckFind(StringPiece attr_name,
+ const AttrValue* attr_value) const {
+ if (attr_value != nullptr) {
return Status::OK();
}
Status s = errors::NotFound("No attr named '", attr_name, "' in NodeDef:");
@@ -197,6 +196,18 @@
return s;
}
+Status AttrSlice::Find(StringPiece attr_name,
+ const AttrValue** attr_value) const {
+ *attr_value = Find(attr_name);
+ return CheckFind(attr_name, *attr_value);
+}
+
+Status AttrSlice::FindByString(const string& attr_name,
+ const AttrValue** attr_value) const {
+ *attr_value = FindByString(attr_name);
+ return CheckFind(attr_name, *attr_value);
+}
+
bool AttrSlice::EqualAttrs(AttrSlice other, Scratch* scratch) const {
if (size() != other.size()) return false;
@@ -483,13 +494,14 @@
}
} else if (!arg_def.type_attr().empty()) {
const AttrValue* attr_value;
- TF_RETURN_IF_ERROR(
- AttrSlice(node_or_attrs).Find(arg_def.type_attr(), &attr_value));
+ TF_RETURN_IF_ERROR(AttrSlice(node_or_attrs)
+ .FindByString(arg_def.type_attr(), &attr_value));
sig->push_back(attr_value->type());
} else if (!arg_def.type_list_attr().empty()) {
const AttrValue* attr_value;
TF_RETURN_IF_ERROR(
- AttrSlice(node_or_attrs).Find(arg_def.type_list_attr(), &attr_value));
+ AttrSlice(node_or_attrs)
+ .FindByString(arg_def.type_list_attr(), &attr_value));
for (int dtype : attr_value->list().type()) {
sig->push_back(static_cast<DataType>(dtype));
}
diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h
index db855b7..dcafd1a 100644
--- a/tensorflow/core/framework/node_def_util.h
+++ b/tensorflow/core/framework/node_def_util.h
@@ -158,6 +158,8 @@
// Returns the attr_value for attr_name if found. Otherwise, returns a
// NotFound status.
Status Find(StringPiece attr_name, const AttrValue** attr_value) const;
+ Status FindByString(const std::string& attr_name,
+ const AttrValue** attr_value) const;
// Helper class to avoid allocations in EqualAttrs.
// TODO(irving): Will go away once NodeInfo is used.
@@ -193,6 +195,8 @@
return ndef_ != nullptr ? &ndef_->attr() : attrs_;
}
+ Status CheckFind(StringPiece attr_name, const AttrValue* attr_value) const;
+
const NodeDef* ndef_;
const AttrValueMap* attrs_;
};