[KPL] Support tf.data.Dataset inputs to TextVectorization layer + unit tests
PiperOrigin-RevId: 277562293
Change-Id: I57f87305d59fe833c979dcd09b012f1a42acf65c
diff --git a/tensorflow/python/keras/layers/preprocessing/text_vectorization.py b/tensorflow/python/keras/layers/preprocessing/text_vectorization.py
index 5c000f8..fed9244 100644
--- a/tensorflow/python/keras/layers/preprocessing/text_vectorization.py
+++ b/tensorflow/python/keras/layers/preprocessing/text_vectorization.py
@@ -23,6 +23,7 @@
import numpy as np
import six
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
@@ -325,11 +326,22 @@
"""
if not reset_state:
raise ValueError("TextVectorization does not support streaming adapts.")
- self.build(data.shape)
- # TODO(b/142870340): Look at passing preprocess as subgraph to dataset.map.
- preprocessed_inputs = self._preprocess(data)
- super(TextVectorization,
- self).adapt(self._to_numpy(preprocessed_inputs), reset_state)
+
+ # Build the layer explicitly with the original data shape instead of relying
+ # on an implicit call to `build` in the base layer's `adapt`, since
+ # preprocessing changes the input shape.
+ if isinstance(data, np.ndarray):
+ self.build(data.shape)
+ preprocessed_inputs = self._to_numpy(self._preprocess(data))
+ elif isinstance(data, dataset_ops.DatasetV2):
+ self.build(dataset_ops.get_legacy_output_shapes(data))
+ preprocessed_inputs = data.map(self._preprocess)
+ else:
+ raise ValueError(
+ "adapt() requires a Dataset or a Numpy array as input, got {}".format(
+ type(data)))
+
+ super(TextVectorization, self).adapt(preprocessed_inputs, reset_state)
def get_vocabulary(self):
if not self._has_vocab:
@@ -624,6 +636,8 @@
if dtypes.as_dtype(self._input_dtype) != dtypes.as_dtype(values.dtype):
raise RuntimeError("Expected input type %s, got %s" %
(self._input_dtype, values.dtype))
+ if ragged_tensor.is_ragged(values):
+ values = values.to_list()
flattened_batch = np.concatenate(values)
vocab, counts = np.unique(flattened_batch, return_counts=True)
if self._compute_idf:
@@ -647,8 +661,9 @@
[getattr(acc, _ACCUMULATOR_VOCAB_NAME) for acc in accumulators])
concat_counts = np.concatenate(
[getattr(acc, _ACCUMULATOR_COUNTS_NAME) for acc in accumulators])
- concat_document_counts = np.concatenate(
- [getattr(acc, _ACCUMULATOR_DOCUMENT_COUNTS) for acc in accumulators])
+ if self._compute_idf:
+ concat_document_counts = np.concatenate(
+ [getattr(acc, _ACCUMULATOR_DOCUMENT_COUNTS) for acc in accumulators])
merged_values, merged_indices = np.unique(concat_vocab, return_inverse=True)
def sum_segment(index, array_to_segment):
diff --git a/tensorflow/python/keras/layers/preprocessing/text_vectorization_test.py b/tensorflow/python/keras/layers/preprocessing/text_vectorization_test.py
index cd56bd0..7f8d9ce 100644
--- a/tensorflow/python/keras/layers/preprocessing/text_vectorization_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/text_vectorization_test.py
@@ -28,6 +28,7 @@
from tensorflow.python import keras
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.keras import keras_parameterized
@@ -48,12 +49,8 @@
return text_vectorization_v1.TextVectorization
-@keras_parameterized.run_all_keras_modes
-class TextVectorizationLayerTest(keras_parameterized.TestCase,
- preprocessing_test_utils.PreprocessingLayerTest
- ):
-
- @parameterized.named_parameters(
+def _get_end_to_end_test_cases():
+ test_cases = (
{
"testcase_name":
"test_simple_tokens_int_mode",
@@ -200,24 +197,64 @@
"split": text_vectorization.SPLIT_ON_WHITESPACE,
"output_mode": text_vectorization.TFIDF
},
- "expected_output":
- [[0., 0.847298, 0.847298, 0., 0.], [0., 0., 0., 1.098612, 0.],
- [0., 0., 0., 0., 2.197225], [1.609438, 0.847298, 0., 0., 0.]],
+ "expected_output": [[0., 0.847298, 0.847298, 0., 0.],
+ [0., 0., 0., 1.098612, 0.],
+ [0., 0., 0., 0., 2.197225],
+ [1.609438, 0.847298, 0., 0., 0.]],
},
)
+
+ crossed_test_cases = []
+ # Cross the test cases above with use_dataset in (True, False).
+ for use_dataset in (True, False):
+ for case in test_cases:
+ case = case.copy()
+ if use_dataset:
+ case["testcase_name"] = case["testcase_name"] + "_with_dataset"
+ case["use_dataset"] = use_dataset
+ crossed_test_cases.append(case)
+
+ return crossed_test_cases
+
+
+@keras_parameterized.run_all_keras_modes
+class TextVectorizationLayerTest(keras_parameterized.TestCase,
+ preprocessing_test_utils.PreprocessingLayerTest
+ ):
+
+ @parameterized.named_parameters(*_get_end_to_end_test_cases())
def test_layer_end_to_end_with_adapt(self, vocab_data, input_data, kwargs,
- expected_output):
+ use_dataset, expected_output):
cls = get_layer_class()
if kwargs.get("output_mode") == text_vectorization.TFIDF:
expected_output_dtype = dtypes.float32
else:
expected_output_dtype = dtypes.int64
+ input_shape = input_data.shape
+
+ if use_dataset:
+ # Keras APIs expect batched datasets.
+ # TODO(rachelim): `model.predict` predicts the result on each
+ # dataset batch separately, then tries to concatenate the results
+ # together. When the results have different shapes on the non-concat
+ # axis (which can happen in the output_mode = INT case for
+ # TextVectorization), the concatenation fails. In real use cases, this may
+ # not be an issue because users are likely to pipe the preprocessing layer
+ # into other keras layers instead of predicting it directly. A workaround
+ # for these unit tests is to have the dataset only contain one batch, so
+ # no concatenation needs to happen with the result. For consistency with
+ # numpy input, we should make `predict` join differently shaped results
+ # together sensibly, with 0 padding.
+ input_data = dataset_ops.Dataset.from_tensor_slices(input_data).batch(
+ input_shape[0])
+ vocab_data = dataset_ops.Dataset.from_tensor_slices(vocab_data).batch(
+ input_shape[0])
with CustomObjectScope({"TextVectorization": cls}):
output_data = testing_utils.layer_test(
cls,
kwargs=kwargs,
- input_shape=(None),
+ input_shape=input_shape,
input_data=input_data,
input_dtype=dtypes.string,
expected_output_dtype=expected_output_dtype,
diff --git a/tensorflow/python/keras/layers/preprocessing/text_vectorization_v1.py b/tensorflow/python/keras/layers/preprocessing/text_vectorization_v1.py
index 2302023..43b4c7c 100644
--- a/tensorflow/python/keras/layers/preprocessing/text_vectorization_v1.py
+++ b/tensorflow/python/keras/layers/preprocessing/text_vectorization_v1.py
@@ -21,11 +21,13 @@
import numpy as np
from tensorflow.python.keras import backend as K
+from tensorflow.python.keras.engine import base_preprocessing_layer_v1
from tensorflow.python.keras.layers.preprocessing import text_vectorization
from tensorflow.python.ops.ragged import ragged_tensor_value
-class TextVectorization(text_vectorization.TextVectorization):
+class TextVectorization(text_vectorization.TextVectorization,
+ base_preprocessing_layer_v1.CombinerPreprocessingLayer):
"""Text vectorization layer.
This layer has basic options for managing text in a Keras model. It