Add support for serializing mutable hash tables as GraphDefs.
The serialized GraphDef stores the table's keys and values as Const nodes and re-imports them into the table (via LookupTableImportV2) when the graph is executed.
This CL also fixes an issue in LookupTableImport shape inference, which previously required the keys and values inputs to have the same rank. That was incorrect: each key's value may be a scalar or a tensor, so the values input may have rank 1 or higher; only its first dimension must match the number of keys.
PiperOrigin-RevId: 366154768
Change-Id: I348a3e7e0ad660863df6ca106f3ebd69a3e8e95e
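
As a usage illustration (not part of the change itself), here is a minimal sketch of the behavior this enables, written against the public TF API rather than the internal test helpers used below; the table contents mirror the new test case.

  import tensorflow as tf

  # A mutable table captured by a tf.data pipeline.
  table = tf.lookup.experimental.MutableHashTable(
      key_dtype=tf.int64, value_dtype=tf.int64, default_value=-1)
  table.insert([0, 1], [10, 11])

  ds = tf.data.Dataset.range(3).map(table.lookup)
  # Serializing `ds` to a GraphDef (as the tf.data service does when
  # distributing it) now embeds the table's keys and values as Const nodes
  # plus a LookupTableImportV2, so the remote copy yields [10, 11, -1].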
diff --git a/tensorflow/core/kernels/lookup_table_op.cc b/tensorflow/core/kernels/lookup_table_op.cc
index b246d4a..a5eba24 100644
--- a/tensorflow/core/kernels/lookup_table_op.cc
+++ b/tensorflow/core/kernels/lookup_table_op.cc
@@ -131,14 +131,7 @@
ctx->allocate_output("keys", TensorShape({size}), &keys));
TF_RETURN_IF_ERROR(
ctx->allocate_output("values", TensorShape({size}), &values));
-
- auto keys_data = keys->flat<K>();
- auto values_data = values->flat<V>();
- int64 i = 0;
- for (auto it = table_.begin(); it != table_.end(); ++it, ++i) {
- keys_data(i) = it->first;
- values_data(i) = it->second;
- }
+ ExportKeysAndValues(keys, values);
return Status::OK();
}
@@ -164,7 +157,57 @@
return sizeof(MutableHashTableOfScalars) + ret;
}
+ Status AsGraphDef(GraphDefBuilder* builder, Node** out) const override {
+ tf_shared_lock l(mu_);
+ int64 size = table_.size();
+ Tensor keys(key_dtype(), TensorShape({size}));
+ Tensor values(value_dtype(), TensorShape({size}));
+ ExportKeysAndValues(&keys, &values);
+
+ // We set use_node_name_sharing with a unique node name so that the resource
+ // can outlive the MutableHashTableV2 kernel. This means that the lifetime
+ // of the resource will be tied to the lifetime of the resource manager it
+ // is created in.
+ // TODO(b/181695913): Provide a mechanism for deleting this resource
+ // earlier when appropriate.
+ Node* table = ops::SourceOp(
+ "MutableHashTableV2",
+ builder->opts()
+ .WithName(UniqueNodeName("MutableHashTableFromGraphDef"))
+ .WithAttr("use_node_name_sharing", true)
+ .WithAttr("key_dtype", key_dtype())
+ .WithAttr("value_dtype", value_dtype()));
+ Node* keys_node = ops::SourceOp(
+ "Const",
+ builder->opts().WithAttr("dtype", key_dtype()).WithAttr("value", keys));
+ Node* values_node =
+ ops::SourceOp("Const", builder->opts()
+ .WithAttr("dtype", value_dtype())
+ .WithAttr("value", values));
+ Node* import_table =
+ ops::TernaryOp("LookupTableImportV2", table, keys_node, values_node,
+ builder->opts()
+ .WithAttr("Tin", key_dtype())
+ .WithAttr("Tout", value_dtype()));
+ *out = ops::UnaryOp("Identity", table,
+ builder->opts().WithControlInput(import_table));
+ return Status::OK();
+ }
+
private:
+ // Writes all keys and values into `keys` and `values`. `keys` and `values`
+ // must point to tensors of size `table_.size()`.
+ void ExportKeysAndValues(Tensor* keys, Tensor* values) const
+ TF_SHARED_LOCKS_REQUIRED(mu_) {
+ auto keys_data = keys->flat<K>();
+ auto values_data = values->flat<V>();
+ int64 i = 0;
+ for (auto it = table_.begin(); it != table_.end(); ++it, ++i) {
+ keys_data(i) = it->first;
+ values_data(i) = it->second;
+ }
+ }
+
mutable mutex mu_;
std::unordered_map<K, V> table_ TF_GUARDED_BY(mu_);
};
@@ -276,18 +319,7 @@
ctx->allocate_output("keys", TensorShape({size}), &keys));
TF_RETURN_IF_ERROR(ctx->allocate_output(
"values", TensorShape({size, value_dim}), &values));
-
- auto keys_data = keys->flat<K>();
- auto values_data = values->matrix<V>();
- int64 i = 0;
- for (auto it = table_.begin(); it != table_.end(); ++it, ++i) {
- K key = it->first;
- ValueArray value = it->second;
- keys_data(i) = key;
- for (int64 j = 0; j < value_dim; j++) {
- values_data(i, j) = value[j];
- }
- }
+ ExportKeysAndValues(keys, values);
return Status::OK();
}
@@ -313,7 +345,63 @@
return sizeof(MutableHashTableOfTensors) + ret;
}
+ Status AsGraphDef(GraphDefBuilder* builder, Node** out) const override {
+ tf_shared_lock l(mu_);
+ int64 size = table_.size();
+ Tensor keys(key_dtype(), TensorShape({size}));
+ Tensor values(value_dtype(), TensorShape({size, value_shape_.dim_size(0)}));
+ ExportKeysAndValues(&keys, &values);
+
+ // We set use_node_name_sharing with a unique node name so that the resource
+ // can outlive the MutableHashTableOfTensorsV2 kernel. This means that the
+ // lifetime of the resource will be tied to the lifetime of the resource
+ // manager it is created in.
+ // TODO(b/181695913): Provide a mechanism for deleting this resource
+ // earlier when appropriate.
+ Node* table =
+ ops::SourceOp("MutableHashTableOfTensorsV2",
+ builder->opts()
+ .WithName(UniqueNodeName("MutableHashTableOfTensors"))
+ .WithAttr("use_node_name_sharing", true)
+ .WithAttr("key_dtype", key_dtype())
+ .WithAttr("value_dtype", value_dtype())
+ .WithAttr("value_shape", value_shape_));
+ Node* keys_node = ops::SourceOp(
+ "Const",
+ builder->opts().WithAttr("dtype", key_dtype()).WithAttr("value", keys));
+ Node* values_node =
+ ops::SourceOp("Const", builder->opts()
+ .WithAttr("dtype", value_dtype())
+ .WithAttr("value", values));
+ Node* import_table =
+ ops::TernaryOp("LookupTableImportV2", table, keys_node, values_node,
+ builder->opts()
+ .WithAttr("Tin", key_dtype())
+ .WithAttr("Tout", value_dtype()));
+ *out = ops::UnaryOp("Identity", table,
+ builder->opts().WithControlInput(import_table));
+ return Status::OK();
+ }
+
private:
+ // Writes all keys and values into `keys` and `values`. `keys` must be a
+ // vector with `table_.size()` elements, and `values` a matrix with
+ // `table_.size()` rows and `value_shape_.dim_size(0)` columns.
+ void ExportKeysAndValues(Tensor* keys, Tensor* values) const
+ TF_SHARED_LOCKS_REQUIRED(mu_) {
+ int64 value_dim = value_shape_.dim_size(0);
+ auto keys_data = keys->flat<K>();
+ auto values_data = values->matrix<V>();
+ int64 i = 0;
+ for (auto it = table_.begin(); it != table_.end(); ++it, ++i) {
+ K key = it->first;
+ ValueArray value = it->second;
+ keys_data(i) = key;
+ for (int64 j = 0; j < value_dim; j++) {
+ values_data(i, j) = value[j];
+ }
+ }
+ }
+
TensorShape value_shape_;
mutable mutex mu_;
typedef gtl::InlinedVector<V, 4> ValueArray;
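
For reference, a rough Python sketch of the subgraph that the AsGraphDef overrides above emit; the C++ builds it with GraphDefBuilder, and the concrete key/value constants here are placeholders for illustration.

  import tensorflow as tf

  g = tf.Graph()
  with g.as_default():
    # MutableHashTableV2 source node with a unique name and
    # use_node_name_sharing=True, matching AsGraphDef above.
    handle = tf.raw_ops.MutableHashTableV2(
        key_dtype=tf.int64, value_dtype=tf.int64,
        use_node_name_sharing=True, name="MutableHashTableFromGraphDef")
    keys = tf.constant([0, 1], dtype=tf.int64)      # exported keys
    values = tf.constant([10, 11], dtype=tf.int64)  # exported values
    import_op = tf.raw_ops.LookupTableImportV2(
        table_handle=handle, keys=keys, values=values)
    # Consumers read the table handle through an Identity that carries a
    # control dependency on the import, so the table is populated first.
    with tf.control_dependencies([import_op]):
      out = tf.identity(handle)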
diff --git a/tensorflow/core/ops/lookup_ops.cc b/tensorflow/core/ops/lookup_ops.cc
index f9aea52..4da1fd7 100644
--- a/tensorflow/core/ops/lookup_ops.cc
+++ b/tensorflow/core/ops/lookup_ops.cc
@@ -309,7 +309,9 @@
ShapeHandle keys;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys));
- TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys));
+ DimensionHandle unused;
+ TF_RETURN_IF_ERROR(
+ c->Merge(c->Dim(keys, 0), c->Dim(c->input(2), 0), &unused));
return Status::OK();
});
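
A hypothetical example of the shape combination the relaxed shape function above must accept (dtypes and sizes made up for illustration): with per-key vector values, the keys and values inputs agree only in their first dimension.

  import tensorflow as tf

  handle = tf.raw_ops.MutableHashTableOfTensorsV2(
      key_dtype=tf.int64, value_dtype=tf.int64, value_shape=[2])
  keys = tf.constant([0, 1], dtype=tf.int64)                  # shape [2]
  values = tf.constant([[10, 10], [11, 11]], dtype=tf.int64)  # shape [2, 2]
  # The old shape function merged the full shapes [2] and [2, 2], which
  # requires equal ranks and so rejected vector values; the fix merges
  # only dimension 0 of the two inputs.
  tf.raw_ops.LookupTableImportV2(table_handle=handle, keys=keys, values=values)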
diff --git a/tensorflow/python/data/experimental/kernel_tests/data_service_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/data_service_ops_test.py
index 6293f5a..21fa281 100644
--- a/tensorflow/python/data/experimental/kernel_tests/data_service_ops_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/data_service_ops_test.py
@@ -120,6 +120,30 @@
self.evaluate(lookup_ops.tables_initializer())
self.assertDatasetProduces(ds, [10, 11, -1], requires_initialization=True)
+ @combinations.generate(
+ combinations.times(test_base.default_test_combinations(),
+ combinations.combine(value_rank=[0, 1])))
+ def testDistributeMutableHashTable(self, value_rank):
+
+ def value(v):
+ for _ in range(value_rank):
+ v = [v, v]
+ return v
+
+ v1 = value(10)
+ v2 = value(11)
+ default_value = value(-1)
+
+ cluster = data_service_test_base.TestCluster(num_workers=1)
+ table = lookup_ops.MutableHashTable(dtypes.int64, dtypes.int64,
+ default_value)
+ self.evaluate(table.insert([0, 1], [v1, v2]))
+ ds = dataset_ops.Dataset.range(3)
+ ds = ds.map(table.lookup)
+ ds = self.make_distributed_dataset(ds, cluster)
+ self.assertDatasetProduces(
+ ds, [v1, v2, default_value], requires_initialization=True)
+
@combinations.generate(test_base.default_test_combinations())
def testDifferentShuffleOrders(self):
random_seed.set_random_seed(None)