Add correctness test for ragged embeddings with the TPUEmbedding mid level API.

PiperOrigin-RevId: 354581897
Change-Id: Ifbbd351e6879d2ac0379520730510866e4bedf1f
diff --git a/tensorflow/python/tpu/tpu_embedding_v2_correctness_test.py b/tensorflow/python/tpu/tpu_embedding_v2_correctness_test.py
index 75235c3..c90ae95 100644
--- a/tensorflow/python/tpu/tpu_embedding_v2_correctness_test.py
+++ b/tensorflow/python/tpu/tpu_embedding_v2_correctness_test.py
@@ -40,6 +40,7 @@
 from tensorflow.python.ops import gen_math_ops
 from tensorflow.python.ops import init_ops_v2
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.platform import test
 from tensorflow.python.tpu import tpu_embedding_v2
 from tensorflow.python.tpu import tpu_embedding_v2_utils
@@ -149,12 +150,17 @@
   @parameterized.parameters(
       *itertools.product(
           ['sgd', 'adagrad', 'adam'],
+          [True, False],
           [True, False]))
-  def test_embedding(self, optimizer_name, training):
+  def test_embedding(self, optimizer_name, training, sparse):
     strategy, mid_level_api, optimizer = (
         self._create_strategy_and_mid_level(optimizer_name))
 
-    dataset = self._create_sparse_dataset(strategy)
+    if sparse:
+      dataset = self._create_sparse_dataset(strategy)
+    else:
+      dataset = self._create_ragged_dataset(strategy)
+
     dist = strategy.experimental_distribute_dataset(
         dataset,
         options=distribute_lib.InputOptions(
@@ -209,8 +215,7 @@
         feature_config=self.feature_config,
         optimizer=optimizer)
 
-  def _create_sparse_dataset(self, strategy, include_weights=False, weight=0.5):
-    # Create dataset for enqueue operation
+  def _create_sparse_data(self, include_weights, weight=0.5):
     sparse_features = (
         sparse_tensor.SparseTensor(
             indices=self.feature_watched_indices,
@@ -234,6 +239,11 @@
             values=values,
             dense_shape=sparse.dense_shape))
       sparse_features = (sparse_features, tuple(weights))
+    return sparse_features
+
+  def _create_sparse_dataset(self, strategy, include_weights=False, weight=0.5):
+    # Create dataset for enqueue operation
+    sparse_features = self._create_sparse_data(include_weights, weight)
 
     dataset = dataset_ops.DatasetV2.from_tensors(sparse_features)
 
@@ -241,6 +251,18 @@
     return dataset.unbatch().repeat().batch(
         self.batch_size * strategy.num_replicas_in_sync, drop_remainder=True)
 
+  def _create_ragged_dataset(self, strategy, include_weights=False, weight=0.5):
+    # Create dataset for enqueue operation
+    sparse_features = self._create_sparse_data(include_weights, weight)
+    ragged_features = nest.map_structure(ragged_tensor.RaggedTensor.from_sparse,
+                                         sparse_features)
+
+    dataset = dataset_ops.DatasetV2.from_tensors(ragged_features)
+
+    # Data is batched to self.data_batch_size, rebatch to global batch size.
+    return dataset.unbatch().repeat().batch(
+        self.batch_size * strategy.num_replicas_in_sync, drop_remainder=True)
+
   def _create_dense_input_fn(self, strategy, include_weights=False, weight=0.5):
 
     def input_fn(ctx):
@@ -448,7 +470,8 @@
                numpy_users[self.feature_friends_values[-2:]]))
     self.assertAllClose(shard0, golden)
 
-  def test_sequence_embeddings(self):
+  @parameterized.parameters([True, False])
+  def test_sequence_embeddings(self, sparse):
     feature_config = (
         tpu_embedding_v2_utils.FeatureConfig(
             table=self.table_video, name='watched',
@@ -470,7 +493,10 @@
     # results in data where the shape of the sparse tensor is a tensor which we
     # can't tell the shape of at tracing time.
     mid_level.build(self.batch_size)
-    dataset = self._create_sparse_dataset(strategy)
+    if sparse:
+      dataset = self._create_sparse_dataset(strategy)
+    else:
+      dataset = self._create_ragged_dataset(strategy)
     data = next(iter(strategy.experimental_distribute_dataset(
         dataset,
         options=distribute_lib.InputOptions(