[KPL] Support tf.data.Dataset inputs to TextVectorization layer + unit tests

PiperOrigin-RevId: 277562293
Change-Id: I57f87305d59fe833c979dcd09b012f1a42acf65c
diff --git a/tensorflow/python/keras/layers/preprocessing/text_vectorization.py b/tensorflow/python/keras/layers/preprocessing/text_vectorization.py
index 5c000f8..fed9244 100644
--- a/tensorflow/python/keras/layers/preprocessing/text_vectorization.py
+++ b/tensorflow/python/keras/layers/preprocessing/text_vectorization.py
@@ -23,6 +23,7 @@
 import numpy as np
 import six
 
+from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_spec
@@ -325,11 +326,22 @@
     """
     if not reset_state:
       raise ValueError("TextVectorization does not support streaming adapts.")
-    self.build(data.shape)
-    # TODO(b/142870340): Look at passing preprocess as subgraph to dataset.map.
-    preprocessed_inputs = self._preprocess(data)
-    super(TextVectorization,
-          self).adapt(self._to_numpy(preprocessed_inputs), reset_state)
+
+    # Build the layer explicitly with the original data shape instead of relying
+    # on an implicit call to `build` in the base layer's `adapt`, since
+    # preprocessing changes the input shape.
+    if isinstance(data, np.ndarray):
+      self.build(data.shape)
+      preprocessed_inputs = self._to_numpy(self._preprocess(data))
+    elif isinstance(data, dataset_ops.DatasetV2):
+      self.build(dataset_ops.get_legacy_output_shapes(data))
+      preprocessed_inputs = data.map(self._preprocess)
+    else:
+      raise ValueError(
+          "adapt() requires a Dataset or a Numpy array as input, got {}".format(
+              type(data)))
+
+    super(TextVectorization, self).adapt(preprocessed_inputs, reset_state)
 
   def get_vocabulary(self):
     if not self._has_vocab:
@@ -624,6 +636,8 @@
     if dtypes.as_dtype(self._input_dtype) != dtypes.as_dtype(values.dtype):
       raise RuntimeError("Expected input type %s, got %s" %
                          (self._input_dtype, values.dtype))
+    if ragged_tensor.is_ragged(values):
+      values = values.to_list()
     flattened_batch = np.concatenate(values)
     vocab, counts = np.unique(flattened_batch, return_counts=True)
     if self._compute_idf:
@@ -647,8 +661,9 @@
         [getattr(acc, _ACCUMULATOR_VOCAB_NAME) for acc in accumulators])
     concat_counts = np.concatenate(
         [getattr(acc, _ACCUMULATOR_COUNTS_NAME) for acc in accumulators])
-    concat_document_counts = np.concatenate(
-        [getattr(acc, _ACCUMULATOR_DOCUMENT_COUNTS) for acc in accumulators])
+    if self._compute_idf:
+      concat_document_counts = np.concatenate(
+          [getattr(acc, _ACCUMULATOR_DOCUMENT_COUNTS) for acc in accumulators])
     merged_values, merged_indices = np.unique(concat_vocab, return_inverse=True)
 
     def sum_segment(index, array_to_segment):
diff --git a/tensorflow/python/keras/layers/preprocessing/text_vectorization_test.py b/tensorflow/python/keras/layers/preprocessing/text_vectorization_test.py
index cd56bd0..7f8d9ce 100644
--- a/tensorflow/python/keras/layers/preprocessing/text_vectorization_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/text_vectorization_test.py
@@ -28,6 +28,7 @@
 
 from tensorflow.python import keras
 
+from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.keras import keras_parameterized
@@ -48,12 +49,8 @@
     return text_vectorization_v1.TextVectorization
 
 
-@keras_parameterized.run_all_keras_modes
-class TextVectorizationLayerTest(keras_parameterized.TestCase,
-                                 preprocessing_test_utils.PreprocessingLayerTest
-                                ):
-
-  @parameterized.named_parameters(
+def _get_end_to_end_test_cases():
+  test_cases = (
       {
           "testcase_name":
               "test_simple_tokens_int_mode",
@@ -200,24 +197,64 @@
               "split": text_vectorization.SPLIT_ON_WHITESPACE,
               "output_mode": text_vectorization.TFIDF
           },
-          "expected_output":
-              [[0., 0.847298, 0.847298, 0., 0.], [0., 0., 0., 1.098612, 0.],
-               [0., 0., 0., 0., 2.197225], [1.609438, 0.847298, 0., 0., 0.]],
+          "expected_output": [[0., 0.847298, 0.847298, 0., 0.],
+                              [0., 0., 0., 1.098612, 0.],
+                              [0., 0., 0., 0., 2.197225],
+                              [1.609438, 0.847298, 0., 0., 0.]],
       },
   )
+
+  crossed_test_cases = []
+  # Cross the above test cases with use_dataset in (True, False).
+  for use_dataset in (True, False):
+    for case in test_cases:
+      case = case.copy()
+      if use_dataset:
+        case["testcase_name"] = case["testcase_name"] + "_with_dataset"
+      case["use_dataset"] = use_dataset
+      crossed_test_cases.append(case)
+
+  return crossed_test_cases
+
+
+@keras_parameterized.run_all_keras_modes
+class TextVectorizationLayerTest(keras_parameterized.TestCase,
+                                 preprocessing_test_utils.PreprocessingLayerTest
+                                ):
+
+  @parameterized.named_parameters(*_get_end_to_end_test_cases())
   def test_layer_end_to_end_with_adapt(self, vocab_data, input_data, kwargs,
-                                       expected_output):
+                                       use_dataset, expected_output):
     cls = get_layer_class()
     if kwargs.get("output_mode") == text_vectorization.TFIDF:
       expected_output_dtype = dtypes.float32
     else:
       expected_output_dtype = dtypes.int64
+    input_shape = input_data.shape
+
+    if use_dataset:
+      # Keras APIs expect batched datasets.
+      # TODO(rachelim): `model.predict` predicts the result on each
+      # dataset batch separately, then tries to concatenate the results
+      # together. When the results have different shapes on the non-concat
+      # axis (which can happen in the output_mode = INT case for
+      # TextVectorization), the concatenation fails. In real use cases, this may
+      # not be an issue because users are likely to pipe the preprocessing layer
+      # into other keras layers instead of predicting it directly. A workaround
+      # for these unit tests is to have the dataset only contain one batch, so
+      # no concatenation needs to happen with the result. For consistency with
+      # numpy input, we should make `predict` join differently shaped results
+      # together sensibly, with 0 padding.
+      input_data = dataset_ops.Dataset.from_tensor_slices(input_data).batch(
+          input_shape[0])
+      vocab_data = dataset_ops.Dataset.from_tensor_slices(vocab_data).batch(
+          input_shape[0])
 
     with CustomObjectScope({"TextVectorization": cls}):
       output_data = testing_utils.layer_test(
           cls,
           kwargs=kwargs,
-          input_shape=(None),
+          input_shape=input_shape,
           input_data=input_data,
           input_dtype=dtypes.string,
           expected_output_dtype=expected_output_dtype,
diff --git a/tensorflow/python/keras/layers/preprocessing/text_vectorization_v1.py b/tensorflow/python/keras/layers/preprocessing/text_vectorization_v1.py
index 2302023..43b4c7c 100644
--- a/tensorflow/python/keras/layers/preprocessing/text_vectorization_v1.py
+++ b/tensorflow/python/keras/layers/preprocessing/text_vectorization_v1.py
@@ -21,11 +21,13 @@
 import numpy as np
 
 from tensorflow.python.keras import backend as K
+from tensorflow.python.keras.engine import base_preprocessing_layer_v1
 from tensorflow.python.keras.layers.preprocessing import text_vectorization
 from tensorflow.python.ops.ragged import ragged_tensor_value
 
 
-class TextVectorization(text_vectorization.TextVectorization):
+class TextVectorization(text_vectorization.TextVectorization,
+                        base_preprocessing_layer_v1.CombinerPreprocessingLayer):
   """Text vectorization layer.
 
   This layer has basic options for managing text in a Keras model. It