Add correctness test for ragged embeddings with the TPUEmbedding mid level API.
PiperOrigin-RevId: 354581897
Change-Id: Ifbbd351e6879d2ac0379520730510866e4bedf1f
diff --git a/tensorflow/python/tpu/tpu_embedding_v2_correctness_test.py b/tensorflow/python/tpu/tpu_embedding_v2_correctness_test.py
index 75235c3..c90ae95 100644
--- a/tensorflow/python/tpu/tpu_embedding_v2_correctness_test.py
+++ b/tensorflow/python/tpu/tpu_embedding_v2_correctness_test.py
@@ -40,6 +40,7 @@
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import init_ops_v2
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu_embedding_v2
from tensorflow.python.tpu import tpu_embedding_v2_utils
@@ -149,12 +150,17 @@
@parameterized.parameters(
*itertools.product(
['sgd', 'adagrad', 'adam'],
+ [True, False],
[True, False]))
- def test_embedding(self, optimizer_name, training):
+ def test_embedding(self, optimizer_name, training, sparse):
strategy, mid_level_api, optimizer = (
self._create_strategy_and_mid_level(optimizer_name))
- dataset = self._create_sparse_dataset(strategy)
+ if sparse:
+ dataset = self._create_sparse_dataset(strategy)
+ else:
+ dataset = self._create_ragged_dataset(strategy)
+
dist = strategy.experimental_distribute_dataset(
dataset,
options=distribute_lib.InputOptions(
@@ -209,8 +215,7 @@
feature_config=self.feature_config,
optimizer=optimizer)
- def _create_sparse_dataset(self, strategy, include_weights=False, weight=0.5):
- # Create dataset for enqueue operation
+ def _create_sparse_data(self, include_weights, weight=0.5):
sparse_features = (
sparse_tensor.SparseTensor(
indices=self.feature_watched_indices,
@@ -234,6 +239,11 @@
values=values,
dense_shape=sparse.dense_shape))
sparse_features = (sparse_features, tuple(weights))
+ return sparse_features
+
+ def _create_sparse_dataset(self, strategy, include_weights=False, weight=0.5):
+ # Create dataset for enqueue operation
+ sparse_features = self._create_sparse_data(include_weights, weight)
dataset = dataset_ops.DatasetV2.from_tensors(sparse_features)
@@ -241,6 +251,18 @@
return dataset.unbatch().repeat().batch(
self.batch_size * strategy.num_replicas_in_sync, drop_remainder=True)
+ def _create_ragged_dataset(self, strategy, include_weights=False, weight=0.5):
+ # Create dataset for enqueue operation
+ sparse_features = self._create_sparse_data(include_weights, weight)
+ ragged_features = nest.map_structure(ragged_tensor.RaggedTensor.from_sparse,
+ sparse_features)
+
+ dataset = dataset_ops.DatasetV2.from_tensors(ragged_features)
+
+ # Data is batched to self.data_batch_size, rebatch to global batch size.
+ return dataset.unbatch().repeat().batch(
+ self.batch_size * strategy.num_replicas_in_sync, drop_remainder=True)
+
def _create_dense_input_fn(self, strategy, include_weights=False, weight=0.5):
def input_fn(ctx):
@@ -448,7 +470,8 @@
numpy_users[self.feature_friends_values[-2:]]))
self.assertAllClose(shard0, golden)
- def test_sequence_embeddings(self):
+ @parameterized.parameters([True, False])
+ def test_sequence_embeddings(self, sparse):
feature_config = (
tpu_embedding_v2_utils.FeatureConfig(
table=self.table_video, name='watched',
@@ -470,7 +493,10 @@
# results in data where the shape of the sparse tensor is a tensor which we
# can't tell the shape of at tracing time.
mid_level.build(self.batch_size)
- dataset = self._create_sparse_dataset(strategy)
+ if sparse:
+ dataset = self._create_sparse_dataset(strategy)
+ else:
+ dataset = self._create_ragged_dataset(strategy)
data = next(iter(strategy.experimental_distribute_dataset(
dataset,
options=distribute_lib.InputOptions(