Add a Lookup layer.

PiperOrigin-RevId: 292968344
Change-Id: I7dadf31138366d5b0ad8ed7eb0a885647d25bc81
diff --git a/tensorflow/python/keras/layers/preprocessing/BUILD b/tensorflow/python/keras/layers/preprocessing/BUILD
index d416af6..999f3b89 100644
--- a/tensorflow/python/keras/layers/preprocessing/BUILD
+++ b/tensorflow/python/keras/layers/preprocessing/BUILD
@@ -67,6 +67,31 @@
 )
 
 py_library(
+    name = "index_lookup",
+    srcs = [
+        "index_lookup.py",
+        "index_lookup_v1.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:control_flow_ops",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:lookup_ops",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:string_ops",
+        "//tensorflow/python:tensor_shape",
+        "//tensorflow/python:tensor_spec",
+        "//tensorflow/python:util",
+        "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/keras:backend",
+        "//tensorflow/python/keras/engine:base_preprocessing_layer",
+        "//tensorflow/python/ops/ragged",
+    ],
+)
+
+py_library(
     name = "normalization",
     srcs = [
         "normalization.py",
@@ -93,6 +118,7 @@
     srcs_version = "PY2AND3",
     deps = [
         ":categorical_encoding",
+        ":index_lookup",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:control_flow_ops",
         "//tensorflow/python:dtypes",
@@ -169,6 +195,22 @@
     ],
 )
 
+tf_py_test(
+    name = "index_lookup_test",
+    size = "medium",
+    srcs = ["index_lookup_test.py"],
+    python_version = "PY3",
+    deps = [
+        ":index_lookup",
+        ":preprocessing_test_utils",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python/keras",
+        "//tensorflow/python/keras/utils:generic_utils",
+        "//tensorflow/python/ops/ragged:ragged_string_ops",
+        "@absl_py//absl/testing:parameterized",
+    ],
+)
+
 cuda_py_test(
     name = "image_preprocessing_test",
     size = "medium",
diff --git a/tensorflow/python/keras/layers/preprocessing/index_lookup.py b/tensorflow/python/keras/layers/preprocessing/index_lookup.py
new file mode 100644
index 0000000..1fb4b6c
--- /dev/null
+++ b/tensorflow/python/keras/layers/preprocessing/index_lookup.py
@@ -0,0 +1,402 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras text vectorization preprocessing layer."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import json
+import operator
+
+import numpy as np
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_spec
+from tensorflow.python.keras.engine.base_preprocessing_layer import Combiner
+from tensorflow.python.keras.engine.base_preprocessing_layer import CombinerPreprocessingLayer
+from tensorflow.python.ops import lookup_ops
+from tensorflow.python.ops.ragged import ragged_functional_ops
+from tensorflow.python.ops.ragged import ragged_tensor
+from tensorflow.python.util import compat
+
+# The string tokens in the extracted vocabulary
+_VOCAB_NAME = "vocab"
+
+# The string tokens in the full vocabulary
+_ACCUMULATOR_VOCAB_NAME = "vocab"
+# The total counts of each token in the vocabulary
+_ACCUMULATOR_COUNTS_NAME = "counts"
+
+
+class IndexLookup(CombinerPreprocessingLayer):
+  """Maps strings (or integers) from a vocabulary to integer indices.
+
+  This layer translates a set of arbitray strings or integers into an integer
+  output via a table-based lookup, with optional out-of-vocabulary handling.
+
+  If desired, the user can call this layer's `adapt()` method on a data set,
+  which will analyze the data set, determine the frequency of individual string
+  or integer values, and create a vocabulary from them. This vocabulary can have
+  unlimited size or be capped, depending on the configuration options for this
+  layer; if there are more unique values in the input than the maximum
+  vocabulary size, the most frequent terms will be used to create the
+  vocabulary.
+
+  Attributes:
+    max_tokens: The maximum size of the vocabulary for this layer. If None,
+      there is no cap on the size of the vocabulary. Note that the vocabulary
+      does include OOV buckets, so the effective number of unique values in the
+      vocabulary is `(max_tokens - num_oov_tokens)` when this value is set.
+    num_oov_tokens: The number of out-of-vocabulary tokens to use; defaults to
+      1. If this value is more than 1, OOV inputs are hashed to determine their
+      OOV value; if this value is 0, passing an OOV input will result in a
+      runtime error.
+    reserve_zero: Whether to reserve the index 0, which indicates pad values in
+      the Keras masking system. If True, the output of this layer will be in the
+      range `[1...max_tokens+1)`; if False, the output will be in the range
+      `[0...max_tokens)`. Defaults to True.
+    mask_zero: If True, input values of 0 (for integers) and `""` (for strings)
+      will be treated as masked values and assigned an output value of 0. If
+      this option is set, `reserve_zero` must also be set. Defaults to False.
+  """
+  # TODO(momernick): Add an examples section to the docstring.
+
+  def __init__(self,
+               max_tokens,
+               num_oov_tokens=1,
+               reserve_zero=True,
+               mask_zero=False,
+               **kwargs):
+    allowed_dtypes = [dtypes.string, dtypes.int64]
+    if "dtype" in kwargs and kwargs["dtype"] not in allowed_dtypes:
+      raise ValueError(
+          "TextVectorization may only have a dtype of string or int64.")
+    elif "dtype" not in kwargs:
+      kwargs["dtype"] = dtypes.string
+
+    # If max_tokens is set, the value must be greater than 1 - otherwise we
+    # are creating a 0-element vocab, which doesn't make sense.
+    if max_tokens is not None and max_tokens <= 1:
+      raise ValueError("max_tokens must be greater than 1.")
+
+    # For now, limit the num_oov_tokens to one.
+    if num_oov_tokens != 1:
+      raise ValueError("num_oov_tokens must be 1 for the time being. Other "
+                       "values will be supported in the near future. "
+                       "You passed %s" % num_oov_tokens)
+
+    self.max_tokens = max_tokens
+    self.num_oov_tokens = num_oov_tokens
+    self.reserve_zero = reserve_zero
+    self.mask_zero = mask_zero
+
+    # We need to reserve at least num_oov_tokens tokens, plus one additional
+    # value if we are reserving the zero value in our output.
+    if reserve_zero:
+      self._reserved_values = (num_oov_tokens + 1)
+    else:
+      self._reserved_values = num_oov_tokens
+
+    # We need to account for the OOV buckets in our vocabulary size.
+    if max_tokens is not None:
+      self._max_elements = max_tokens - num_oov_tokens
+    else:
+      self._max_elements = None
+
+    # If there is only one OOV bucket, we can determine the OOV value (either 0
+    # or 1 depending on whether 0 is reserved) and set that as the default
+    # value of the index_lookup table. If we hav multiple OOV values, we need to
+    # do a further hashing step; to make this easier, we set the OOV value to
+    # -1. (This lets us do a vectorized add and cast to boolean to determine
+    # locations where we need to do extra hashing.)
+    if self.num_oov_tokens == 1:
+      self._oov_value = 1 if reserve_zero else 0
+    else:
+      self._oov_value = -1
+
+    super(IndexLookup, self).__init__(
+        combiner=_IndexLookupCombiner(self.max_tokens), **kwargs)
+
+    # This layer supports RaggedTensor inputs.
+    self._supports_ragged_inputs = True
+
+    # If the layer's input type is int32, we can only output int32 values -
+    # MutableHashTable doesn't allow us to map int32->int64.
+    if self.dtype == dtypes.int32:
+      self._output_dtype = dtypes.int32
+    else:
+      self._output_dtype = dtypes.int64
+
+    self._table = lookup_ops.MutableHashTable(
+        key_dtype=self.dtype,
+        value_dtype=self._output_dtype,
+        default_value=self._oov_value,
+        name=(self._name + "_index_table"))
+
+    # This is a workaround for saving not working yet for MutableHashTables.
+    # By replacing the existing function call by an explicit failure, we
+    # can provide a more user-friendly error message.
+    def fail(_):
+      raise NotImplementedError(
+          "Saving is not yet supported for IndexLookup layers.")
+
+    self._table._list_extra_dependencies_for_serialization = fail  # pylint: disable=protected-access
+
+    tracked_table = self._add_trackable(self._table, trainable=False)
+    # This is a workaround for summary() on this layer. Because the table is
+    # not mutable during training, the effective number of parameters (and so
+    # the weight shape) is 0; we add this as an attr so that the parameter
+    # counting code in the Model object doesn't throw an attribute error.
+    tracked_table.shape = tensor_shape.TensorShape((0,))
+
+  def _get_table_data(self):
+    keys, values = self._table.export()
+    return (keys.numpy(), values.numpy())
+
+  def vocab_size(self):
+    return self._table.size().numpy()
+
+  def _clear_table(self):
+    keys, _ = self._table.export()
+    self._table.remove(keys)
+
+  def _insert_table_data(self, keys, values):
+    if len(values) != len(keys):
+      raise RuntimeError("Size mismatch between values and key arrays. "
+                         "Keys had size %s, values had size %s." %
+                         (len(keys), len(values)))
+    self._table.insert(keys, values)
+
+  def _to_numpy(self, preprocessed_data):
+    """Converts preprocessed inputs into numpy arrays."""
+    if isinstance(preprocessed_data, np.ndarray):
+      return preprocessed_data
+    return np.array(preprocessed_data.to_list())
+  # End of V1/V2 shim points.
+
+  def _assert_same_type(self, expected_type, values, value_name):
+    if dtypes.as_dtype(expected_type) != dtypes.as_dtype(values.dtype):
+      raise RuntimeError("Expected %s type %s, got %s" %
+                         (value_name, expected_type, values.dtype))
+
+  def _convert_to_ndarray(self, x):
+    return np.array(x) if isinstance(x, (list, tuple)) else x
+
+  def compute_output_shape(self, input_shape):
+    return input_shape
+
+  def compute_output_signature(self, input_spec):
+    output_shape = self.compute_output_shape(input_spec.shape.as_list())
+    output_dtype = dtypes.int64
+    return tensor_spec.TensorSpec(shape=output_shape, dtype=output_dtype)
+
+  def adapt(self, data, reset_state=True):
+    """Fits the state of the preprocessing layer to the dataset.
+
+    Overrides the default adapt method to apply relevant preprocessing to the
+    inputs before passing to the combiner.
+
+    Arguments:
+      data: The data to train on. It can be passed either as a tf.data Dataset,
+        or as a numpy array.
+      reset_state: Optional argument specifying whether to clear the state of
+        the layer at the start of the call to `adapt`. This must be True for
+        this layer, which does not support repeated calls to `adapt`.
+    """
+    if not reset_state:
+      raise ValueError("IndexLookup does not support streaming adapts.")
+    super(IndexLookup, self).adapt(data, reset_state)
+
+  def get_vocabulary(self):
+    if self.vocab_size() == 0:
+      return []
+
+    keys, values = self._get_table_data()
+    # This is required because the MutableHashTable doesn't preserve insertion
+    # order, but we rely on the order of the array to assign indices.
+    return [x for _, x in sorted(zip(values, keys))]
+
+  def get_config(self):
+    config = {
+        "max_tokens": self.max_tokens,
+        "num_oov_tokens": self.num_oov_tokens,
+        "reserve_zero": self.reserve_zero,
+        "mask_zero": self.mask_zero
+    }
+    base_config = super(IndexLookup, self).get_config()
+    return dict(list(base_config.items()) + list(config.items()))
+
+  def count_params(self):
+    # This method counts the number of scalars in the weights of this layer.
+    # Since this layer doesn't have any /actual/ weights (in that there's
+    # nothing in this layer that can be trained - we only use the weight
+    # abstraction for ease of saving!) we return 0.
+    return 0
+
+  def set_vocabulary(self,
+                     vocab,
+                     append=False):
+    """Sets vocabulary (and optionally document frequency) data for this layer.
+
+    This method sets the vocabulary for this layer directly, instead of
+    analyzing a dataset through 'adapt'. It should be used whenever the vocab
+    information is already known. If vocabulary data is already present in the
+    layer, this method will either replace it, if 'append' is set to False, or
+    append to it (if 'append' is set to True).
+
+    Arguments:
+      vocab: An array of string tokens.
+      append: Whether to overwrite or append any existing vocabulary data.
+
+    Raises:
+      ValueError: If there are too many inputs, the inputs do not match, or
+        input data is missing.
+    """
+    current_table_size = self.vocab_size()
+    total_vocab_size = len(vocab) + (current_table_size if append else 0)
+    if self.max_tokens is not None and total_vocab_size > self._max_elements:
+      raise ValueError(
+          "Attempted to set a vocabulary larger than the maximum vocab size. "
+          "Passed vocab size is %s, max vocab size is %s. Note that the OOV "
+          "token(s) are automatically added to the number of tokens." %
+          (total_vocab_size, self.max_tokens))
+
+    start_index = self._reserved_values + (self.vocab_size() if append else 0)
+    values = np.arange(start_index, len(vocab) + start_index, dtype=np.int64)
+
+    vocab = self._convert_to_ndarray(vocab)
+    self._assert_same_type(self.dtype, vocab, "vocab")
+
+    values = self._convert_to_ndarray(values)
+    self._assert_same_type(self._output_dtype, values, "values")
+
+    if not append and self.vocab_size() > 0:
+      self._clear_table()
+    self._insert_table_data(vocab, values)
+
+  def _set_state_variables(self, updates):
+    if not self.built:
+      raise RuntimeError("_set_state_variables() must be called after build().")
+    self.set_vocabulary(updates[_VOCAB_NAME])
+
+  def call(self, inputs):
+    # The table lookup ops don't natively support ragged tensors, so if we have
+    # a RT we need to use map_flat_values to look up every element.
+    if ragged_tensor.is_ragged(inputs):
+      indexed_data = ragged_functional_ops.map_flat_values(
+          self._table.lookup, inputs)
+    else:
+      indexed_data = self._table.lookup(inputs)
+
+    return indexed_data
+
+
+class _IndexLookupCombiner(Combiner):
+  """Combiner for the IndexLookup preprocessing layer.
+
+  This class encapsulates the logic for computing a vocabulary based on the
+  frequency of each token.
+
+  Attributes:
+    vocab_size: (Optional) If set, only the top `vocab_size` tokens (based on
+      frequency across the dataset) are retained in the vocabulary. If None, or
+      set to a value greater than the total number of distinct tokens in the
+      dataset, all tokens are retained.
+  """
+  ACCUMULATOR_CLS = collections.namedtuple("Accumulator", ["count_dict"])
+
+  def __init__(self, vocab_size=None):
+    self._vocab_size = vocab_size
+
+  def compute(self, values, accumulator=None):
+    """Compute a step in this computation, returning a new accumulator."""
+    if ragged_tensor.is_ragged(values):
+      values = values.to_list()
+    if isinstance(values, ops.EagerTensor):
+      values = values.numpy()
+    if isinstance(values, np.ndarray):
+      values = values.tolist()
+
+    if accumulator is None:
+      accumulator = self._create_accumulator()
+
+    # TODO(momernick): Benchmark improvements to this algorithm.
+    for document in values:
+      for token in document:
+        accumulator.count_dict[token] += 1
+
+    return accumulator
+
+  def merge(self, accumulators):
+    """Merge several accumulators to a single accumulator."""
+    if not accumulators:
+      return accumulators
+
+    base_accumulator = accumulators[0]
+    for accumulator in accumulators[1:]:
+      for token, value in accumulator.count_dict.items():
+        base_accumulator.count_dict[token] += value
+
+    return base_accumulator
+
+  def extract(self, accumulator):
+    """Convert an accumulator into a dict of output values.
+
+    Args:
+      accumulator: An accumulator aggregating over the full dataset.
+
+    Returns:
+      A dict of:
+        "vocab": A list of the retained items in the vocabulary.
+    """
+    vocab_counts = accumulator.count_dict
+    sorted_counts = sorted(
+        vocab_counts.items(), key=operator.itemgetter(1, 0), reverse=True)
+    vocab_data = (
+        sorted_counts[:self._vocab_size] if self._vocab_size else sorted_counts)
+    vocab = [data[0] for data in vocab_data]
+    return {_VOCAB_NAME: vocab}
+
+  def restore(self, output):
+    """Create an accumulator based on 'output'."""
+    raise NotImplementedError(
+        "IndexLookup does not restore or support streaming updates.")
+
+  def serialize(self, accumulator):
+    """Serialize an accumulator for a remote call."""
+    output_dict = {}
+    output_dict["vocab"] = list(accumulator.count_dict.keys())
+    output_dict["vocab_counts"] = list(accumulator.count_dict.values())
+    return compat.as_bytes(json.dumps(output_dict))
+
+  def deserialize(self, encoded_accumulator):
+    """Deserialize an accumulator received from 'serialize()'."""
+    accumulator_dict = json.loads(compat.as_text(encoded_accumulator))
+
+    accumulator = self._create_accumulator()
+    count_dict = dict(
+        zip(accumulator_dict["vocab"], accumulator_dict["vocab_counts"]))
+    accumulator.count_dict.update(count_dict)
+
+    return accumulator
+
+  def _create_accumulator(self):
+    """Accumulate a sorted array of vocab tokens and corresponding counts."""
+
+    count_dict = collections.defaultdict(int)
+    return self.ACCUMULATOR_CLS(count_dict)
diff --git a/tensorflow/python/keras/layers/preprocessing/index_lookup_test.py b/tensorflow/python/keras/layers/preprocessing/index_lookup_test.py
new file mode 100644
index 0000000..67bbe80
--- /dev/null
+++ b/tensorflow/python/keras/layers/preprocessing/index_lookup_test.py
@@ -0,0 +1,481 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Keras text vectorization preprocessing layer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python import keras
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import context
+from tensorflow.python.framework import dtypes
+from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras import testing_utils
+from tensorflow.python.keras.layers.preprocessing import index_lookup
+from tensorflow.python.keras.layers.preprocessing import index_lookup_v1
+from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
+from tensorflow.python.keras.saving import saved_model_experimental as saving
+from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
+from tensorflow.python.platform import test
+
+
+def get_layer_class():
+  if context.executing_eagerly():
+    return index_lookup.IndexLookup
+  else:
+    return index_lookup_v1.IndexLookup
+
+
+def _get_end_to_end_test_cases():
+  test_cases = (
+      {
+          "testcase_name":
+              "test_strings_soft_vocab_cap",
+          # Create an array where 'earth' is the most frequent term, followed by
+          # 'wind', then 'and', then 'fire'. This ensures that the vocab
+          # accumulator is sorting by frequency.
+          "vocab_data":
+              np.array([["fire"], ["earth"], ["earth"], ["earth"], ["earth"],
+                        ["wind"], ["wind"], ["wind"], ["and"], ["and"]]),
+          "input_data":
+              np.array([["earth"], ["wind"], ["and"], ["fire"], ["fire"],
+                        ["and"], ["earth"], ["michigan"]]),
+          "kwargs": {
+              "max_tokens": None,
+          },
+          "expected_output": [[2], [3], [4], [5], [5], [4], [2], [1]],
+          "input_dtype":
+              dtypes.string
+      },
+      {
+          "testcase_name":
+              "test_ints_soft_vocab_cap",
+          # Create an array where 1138 is the most frequent term, followed by
+          # 1729, then 725, then 42. This ensures that the vocab accumulator
+          # is sorting by frequency.
+          "vocab_data":
+              np.array([[42], [1138], [1138], [1138], [1138], [1729], [1729],
+                        [1729], [725], [725]]),
+          "input_data":
+              np.array([[1138], [1729], [725], [42], [42], [725], [1138], [4]]),
+          "kwargs": {
+              "max_tokens": None,
+              "dtype": dtypes.int64,
+          },
+          "expected_output": [[2], [3], [4], [5], [5], [4], [2], [1]],
+          "input_dtype":
+              dtypes.int64
+      },
+  )
+
+  crossed_test_cases = []
+  # Cross above test cases with use_dataset in (True, False)
+  for use_dataset in (True, False):
+    for case in test_cases:
+      case = case.copy()
+      if use_dataset:
+        case["testcase_name"] = case["testcase_name"] + "_with_dataset"
+      case["use_dataset"] = use_dataset
+      crossed_test_cases.append(case)
+
+  return crossed_test_cases
+
+
+@keras_parameterized.run_all_keras_modes
+class IndexLookupLayerTest(keras_parameterized.TestCase,
+                           preprocessing_test_utils.PreprocessingLayerTest):
+
+  @parameterized.named_parameters(*_get_end_to_end_test_cases())
+  def test_layer_end_to_end_with_adapt(self, vocab_data, input_data, kwargs,
+                                       use_dataset, expected_output,
+                                       input_dtype):
+    cls = get_layer_class()
+    expected_output_dtype = dtypes.int64
+    input_shape = input_data.shape
+
+    if use_dataset:
+      # Keras APIs expect batched datasets.
+      # TODO(rachelim): `model.predict` predicts the result on each
+      # dataset batch separately, then tries to concatenate the results
+      # together. When the results have different shapes on the non-concat
+      # axis (which can happen in the output_mode = INT case for
+      # IndexLookup), the concatenation fails. In real use cases, this may
+      # not be an issue because users are likely to pipe the preprocessing layer
+      # into other keras layers instead of predicting it directly. A workaround
+      # for these unit tests is to have the dataset only contain one batch, so
+      # no concatenation needs to happen with the result. For consistency with
+      # numpy input, we should make `predict` join differently shaped results
+      # together sensibly, with 0 padding.
+      input_data = dataset_ops.Dataset.from_tensor_slices(input_data).batch(
+          input_shape[0])
+      vocab_data = dataset_ops.Dataset.from_tensor_slices(vocab_data).batch(
+          input_shape[0])
+
+    with CustomObjectScope({"IndexLookup": cls}):
+      output_data = testing_utils.layer_test(
+          cls,
+          kwargs=kwargs,
+          input_shape=input_shape,
+          input_data=input_data,
+          input_dtype=input_dtype,
+          expected_output_dtype=expected_output_dtype,
+          validate_training=False,
+          adapt_data=vocab_data)
+    self.assertAllClose(expected_output, output_data)
+
+
+@keras_parameterized.run_all_keras_modes
+class IndexLookupOutputTest(keras_parameterized.TestCase,
+                            preprocessing_test_utils.PreprocessingLayerTest):
+
+  def test_int_output(self):
+    vocab_data = ["earth", "wind", "and", "fire"]
+    input_array = np.array([["earth", "wind", "and", "fire"],
+                            ["fire", "and", "earth", "michigan"]])
+    expected_output = [[2, 3, 4, 5], [5, 4, 2, 1]]
+
+    input_data = keras.Input(shape=(None,), dtype=dtypes.string)
+    layer = get_layer_class()(max_tokens=None)
+    layer.set_vocabulary(vocab_data)
+    int_data = layer(input_data)
+    model = keras.Model(inputs=input_data, outputs=int_data)
+    output_dataset = model.predict(input_array)
+    self.assertAllEqual(expected_output, output_dataset)
+
+  def test_int_output_no_reserved_zero(self):
+    vocab_data = ["earth", "wind", "and", "fire"]
+    input_array = np.array([["earth", "wind", "and", "fire"],
+                            ["fire", "and", "earth", "michigan"]])
+    expected_output = [[1, 2, 3, 4], [4, 3, 1, 0]]
+
+    input_data = keras.Input(shape=(None,), dtype=dtypes.string)
+    layer = get_layer_class()(max_tokens=None, reserve_zero=False)
+    layer.set_vocabulary(vocab_data)
+    int_data = layer(input_data)
+    model = keras.Model(inputs=input_data, outputs=int_data)
+    output_dataset = model.predict(input_array)
+    self.assertAllEqual(expected_output, output_dataset)
+
+  def test_vocab_appending(self):
+    vocab_data = [["earth", "wind"], ["and", "fire"]]
+    input_array = np.array([["earth", "wind", "and", "fire"],
+                            ["fire", "and", "earth", "michigan"]])
+    expected_output = [[2, 3, 4, 5], [5, 4, 2, 1]]
+
+    input_data = keras.Input(shape=(None,), dtype=dtypes.string)
+    layer = get_layer_class()(max_tokens=5)
+    layer.set_vocabulary(vocab_data[0])
+    layer.set_vocabulary(vocab_data[1], append=True)
+    int_data = layer(input_data)
+    model = keras.Model(inputs=input_data, outputs=int_data)
+    output_dataset = model.predict(input_array)
+    self.assertAllClose(expected_output, output_dataset)
+
+
+@keras_parameterized.run_all_keras_modes(always_skip_eager=True)
+class IndexLookupSaveableTest(keras_parameterized.TestCase,
+                              preprocessing_test_utils.PreprocessingLayerTest):
+
+  def test_ops_are_not_added_with_multiple_get_set_weights(self):
+    vocab_data = ["earth", "wind", "and", "fire"]
+
+    input_data = keras.Input(shape=(None,), dtype=dtypes.string)
+    layer = get_layer_class()(max_tokens=10)
+    layer.set_vocabulary(vocab_data)
+    int_data = layer(input_data)
+    model = keras.Model(inputs=input_data, outputs=int_data)
+    weights = model.get_weights()
+    model.set_weights(weights)
+    keras.backend.get_session().graph.finalize()
+    weights = model.get_weights()
+    model.set_weights(weights)
+
+
+@keras_parameterized.run_all_keras_modes
+class IndexLookupErrorTest(keras_parameterized.TestCase,
+                           preprocessing_test_utils.PreprocessingLayerTest):
+
+  def test_too_long_vocab_fails_in_single_setting(self):
+    vocab_data = ["earth", "wind", "and", "fire"]
+
+    layer = get_layer_class()(max_tokens=4)
+    with self.assertRaisesRegex(ValueError,
+                                "vocabulary larger than the maximum vocab.*"):
+      layer.set_vocabulary(vocab_data)
+
+  def test_too_long_vocab_fails_in_multiple_settings(self):
+    vocab_data = [["earth", "wind"], ["and", "fire"]]
+    layer = get_layer_class()(max_tokens=4)
+
+    # The first time we call set_vocabulary, we're under the max_tokens
+    # so it should be fine.
+    layer.set_vocabulary(vocab_data[0])
+    with self.assertRaisesRegex(ValueError,
+                                "vocabulary larger than the maximum vocab.*"):
+      layer.set_vocabulary(vocab_data[1], append=True)
+
+  def test_zero_max_tokens_fails(self):
+    with self.assertRaisesRegex(ValueError, ".*max_tokens.*"):
+      _ = get_layer_class()(max_tokens=0)
+
+
+@keras_parameterized.run_all_keras_modes
+class IndexLookupSavingTest(keras_parameterized.TestCase,
+                            preprocessing_test_utils.PreprocessingLayerTest):
+
+  def test_saving_errors(self):
+    vocab_data = ["earth", "wind", "and", "fire"]
+
+    # Build and validate a golden model.
+    input_data = keras.Input(shape=(None,), dtype=dtypes.string)
+    layer = get_layer_class()(max_tokens=None)
+    layer.set_vocabulary(vocab_data)
+    int_data = layer(input_data)
+    model = keras.Model(inputs=input_data, outputs=int_data)
+
+    # Save the model to disk.
+    output_path = os.path.join(self.get_temp_dir(), "tf_keras_saved_model")
+
+    with self.assertRaisesRegex(NotImplementedError, ".*Saving is not yet.*"):
+      model.save(output_path, save_format="tf")
+
+  def test_saving_errors_when_nested(self):
+    vocab_data = ["earth", "wind", "and", "fire"]
+
+    # Build and validate a golden model.
+    input_data = keras.Input(shape=(None,), dtype=dtypes.string)
+    layer = get_layer_class()(max_tokens=None)
+    layer.set_vocabulary(vocab_data)
+    int_data = layer(input_data)
+    model = keras.Model(inputs=input_data, outputs=int_data)
+
+    outer_input = keras.Input(shape=(None,), dtype=dtypes.string)
+    outer_output = model(outer_input)
+    outer_model = keras.Model(inputs=outer_input, outputs=outer_output)
+
+    # Save the model to disk.
+    output_path = os.path.join(self.get_temp_dir(), "tf_keras_saved_model")
+
+    with self.assertRaisesRegex(NotImplementedError, ".*Saving is not yet.*"):
+      outer_model.save(output_path, save_format="tf")
+
+  def DISABLED_test_vocabulary_persistence_across_saving(self):
+    vocab_data = ["earth", "wind", "and", "fire"]
+    input_array = np.array([["earth", "wind", "and", "fire"],
+                            ["fire", "and", "earth", "michigan"]])
+    expected_output = [[2, 3, 4, 5], [5, 4, 2, 1]]
+
+    # Build and validate a golden model.
+    input_data = keras.Input(shape=(None,), dtype=dtypes.string)
+    layer = get_layer_class()(max_tokens=None)
+    layer.set_vocabulary(vocab_data)
+    int_data = layer(input_data)
+    model = keras.Model(inputs=input_data, outputs=int_data)
+    output_dataset = model.predict(input_array)
+    self.assertAllEqual(output_dataset, expected_output)
+
+    # Save the model to disk.
+    output_path = os.path.join(self.get_temp_dir(), "tf_keras_saved_model")
+    model.save(output_path, save_format="tf")
+    loaded_model = saving.load_from_saved_model(
+        output_path, custom_objects={"IndexLookup": get_layer_class()})
+
+    # Ensure that the loaded model is unique (so that the save/load is real)
+    self.assertIsNot(model, loaded_model)
+
+    # Validate correctness of the new model.
+    new_output_dataset = loaded_model.predict(input_array)
+    self.assertAllEqual(new_output_dataset, expected_output)
+
+  def DISABLED_test_vocabulary_persistence_across_saving_with_tfidf(self):
+    vocab_data = ["earth", "wind", "and", "fire"]
+    tfidf_data = [.5, .25, .2, .125]
+    input_array = np.array([["earth", "wind", "and", "earth"],
+                            ["ohio", "fire", "earth", "michigan"]])
+
+    # pyformat: disable
+    # pylint: disable=bad-whitespace
+    expected_output = [[ 0,  1, .25, .2,    0],
+                       [.1, .5,   0,  0, .125]]
+    # pylint: enable=bad-whitespace
+    # pyformat: enable
+
+    # Build and validate a golden model.
+    input_data = keras.Input(shape=(None,), dtype=dtypes.string)
+    layer = get_layer_class()(
+        max_tokens=5,
+        standardize=None,
+        split=None,
+        output_mode=index_lookup.TFIDF)
+    layer.set_vocabulary(vocab_data, df_data=tfidf_data, oov_df_value=.05)
+
+    int_data = layer(input_data)
+    model = keras.Model(inputs=input_data, outputs=int_data)
+    output_dataset = model.predict(input_array)
+    self.assertAllClose(output_dataset, expected_output)
+
+    # Save the model to disk.
+    output_path = os.path.join(self.get_temp_dir(), "tf_keras_saved_model")
+    model.save(output_path, save_format="tf")
+    loaded_model = saving.load_from_saved_model(
+        output_path, custom_objects={"IndexLookup": get_layer_class()})
+
+    # Ensure that the loaded model is unique (so that the save/load is real)
+    self.assertIsNot(model, loaded_model)
+
+    # Validate correctness of the new model.
+    new_output_dataset = loaded_model.predict(input_array)
+    self.assertAllClose(new_output_dataset, expected_output)
+
+
+@keras_parameterized.run_all_keras_modes
+class IndexLookupCombinerTest(keras_parameterized.TestCase,
+                              preprocessing_test_utils.PreprocessingLayerTest):
+
+  def compare_text_accumulators(self, a, b, msg=None):
+    if a is None or b is None:
+      self.assertAllEqual(a, b, msg=msg)
+
+    self.assertAllEqual(a.count_dict, b.count_dict, msg=msg)
+
+  compare_accumulators = compare_text_accumulators
+
+  def update_accumulator(self, accumulator, data):
+    accumulator.count_dict.update(dict(zip(data["vocab"], data["counts"])))
+
+    return accumulator
+
+  def test_combiner_api_compatibility_int_mode(self):
+    data = np.array([["earth", "wind", "and", "fire"],
+                     ["earth", "wind", "and", "michigan"]])
+    combiner = index_lookup._IndexLookupCombiner()
+    expected_accumulator_output = {
+        "vocab": np.array(["and", "earth", "wind", "fire", "michigan"]),
+        "counts": np.array([2, 2, 2, 1, 1]),
+    }
+    expected_extract_output = {
+        "vocab": np.array(["wind", "earth", "and", "michigan", "fire"]),
+    }
+    expected_accumulator = combiner._create_accumulator()
+    expected_accumulator = self.update_accumulator(expected_accumulator,
+                                                   expected_accumulator_output)
+    self.validate_accumulator_serialize_and_deserialize(combiner, data,
+                                                        expected_accumulator)
+    self.validate_accumulator_uniqueness(combiner, data)
+    self.validate_accumulator_extract(combiner, data, expected_extract_output)
+
+  # TODO(askerryryan): Add tests confirming equivalence to behavior of
+  # existing tf.keras.preprocessing.text.Tokenizer.
+  @parameterized.named_parameters(
+      {
+          "testcase_name":
+              "top_k_smaller_than_full_vocab",
+          "data":
+              np.array([["earth", "wind"], ["fire", "wind"], ["and"],
+                        ["fire", "wind"]]),
+          "vocab_size":
+              3,
+          "expected_accumulator_output": {
+              "vocab": np.array(["wind", "fire", "earth", "and"]),
+              "counts": np.array([3, 2, 1, 1]),
+          },
+          "expected_extract_output": {
+              "vocab": np.array(["wind", "fire", "earth"]),
+          },
+      },
+      {
+          "testcase_name":
+              "top_k_larger_than_full_vocab",
+          "data":
+              np.array([["earth", "wind"], ["fire", "wind"], ["and"],
+                        ["fire", "wind"]]),
+          "vocab_size":
+              10,
+          "expected_accumulator_output": {
+              "vocab": np.array(["wind", "fire", "earth", "and"]),
+              "counts": np.array([3, 2, 1, 1]),
+          },
+          "expected_extract_output": {
+              "vocab": np.array(["wind", "fire", "earth", "and"]),
+          },
+      },
+      {
+          "testcase_name":
+              "no_top_k",
+          "data":
+              np.array([["earth", "wind"], ["fire", "wind"], ["and"],
+                        ["fire", "wind"]]),
+          "vocab_size":
+              None,
+          "expected_accumulator_output": {
+              "vocab": np.array(["wind", "fire", "earth", "and"]),
+              "counts": np.array([3, 2, 1, 1]),
+          },
+          "expected_extract_output": {
+              "vocab": np.array(["wind", "fire", "earth", "and"]),
+          },
+      },
+      {
+          "testcase_name": "single_element_per_row",
+          "data": np.array([["earth"], ["wind"], ["fire"], ["wind"], ["and"]]),
+          "vocab_size": 3,
+          "expected_accumulator_output": {
+              "vocab": np.array(["wind", "and", "earth", "fire"]),
+              "counts": np.array([2, 1, 1, 1]),
+          },
+          "expected_extract_output": {
+              "vocab": np.array(["wind", "fire", "earth"]),
+          },
+      },
+      # Which tokens are retained are based on global frequency, and thus are
+      # sensitive to frequency within a document. In contrast, because idf only
+      # considers the presence of a token in a document, it is insensitive
+      # to the frequency of the token within the document.
+      {
+          "testcase_name":
+              "retained_tokens_sensitive_to_within_document_frequency",
+          "data":
+              np.array([["earth", "earth"], ["wind", "wind"], ["fire", "fire"],
+                        ["wind", "wind"], ["and", "michigan"]]),
+          "vocab_size":
+              3,
+          "expected_accumulator_output": {
+              "vocab": np.array(["wind", "earth", "fire", "and", "michigan"]),
+              "counts": np.array([4, 2, 2, 1, 1]),
+          },
+          "expected_extract_output": {
+              "vocab": np.array(["wind", "fire", "earth"]),
+          },
+      })
+  def test_combiner_computation(self, data, vocab_size,
+                                expected_accumulator_output,
+                                expected_extract_output):
+    combiner = index_lookup._IndexLookupCombiner(vocab_size=vocab_size)
+    expected_accumulator = combiner._create_accumulator()
+    expected_accumulator = self.update_accumulator(expected_accumulator,
+                                                   expected_accumulator_output)
+    self.validate_accumulator_computation(combiner, data, expected_accumulator)
+    self.validate_accumulator_extract(combiner, data, expected_extract_output)
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/python/keras/layers/preprocessing/index_lookup_v1.py b/tensorflow/python/keras/layers/preprocessing/index_lookup_v1.py
new file mode 100644
index 0000000..cb5691a
--- /dev/null
+++ b/tensorflow/python/keras/layers/preprocessing/index_lookup_v1.py
@@ -0,0 +1,86 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tensorflow V1 version of the text vectorization preprocessing layer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.keras import backend as K
+from tensorflow.python.keras.engine import base_preprocessing_layer_v1
+from tensorflow.python.keras.layers.preprocessing import index_lookup
+from tensorflow.python.ops.ragged import ragged_tensor_value
+
+
+class IndexLookup(index_lookup.IndexLookup,
+                  base_preprocessing_layer_v1.CombinerPreprocessingLayer):
+  """IndexLookup layer.
+
+  This layer translates a set of arbitray strings or integers into an integer
+  output via a table-based lookup, with optional out-of-vocabulary handling.
+
+  If desired, the user can call this layer's adapt() method on a data set.
+  When this layer is adapted, it will analyze the dataset, determine the
+  frequency of individual string or integer values, and create a vocabulary
+  from them. This vocabulary can have unlimited size or be capped, depending on
+  the configuration options for this layer; if there are more unique values in
+  the input than the maximum vocabulary size, the most frequent terms will be
+  used to create the vocabulary.
+
+  Attributes:
+    max_vocab_size: The maximum size of the vocabulary for this layer. If None,
+      there is no cap on the size of the vocabulary. Note that the vocabulary
+      does include OOV buckets, so the effective number of unique values in the
+      vocabulary is (max_vocab_size - num_oov_buckets) when this value is set.
+    num_oov_buckets: The number of out-of-vocabulary tokens to use; defaults to
+      1. If this value is more than 1, OOV inputs are hashed to determine their
+      OOV value; if this value is 0, passing an OOV input will result in a
+      runtime error.
+    reserve_zero: Whether to reserve the index '0', which has a special meaning
+      in the Keras masking system. If True, the output of this layer will be in
+      the range [1...max_vocab_size+1); if False, the output will be in the
+      range [0...max_vocab_size). Defaults to True.
+    mask_inputs: If True, input values of 0 (for integers) and "" (for strings)
+      will be treated as masked values and assigned an output value of 0. If
+      this option is set, reserve_zero must also be set. Defaults to False.
+  """
+
+  def _get_table_data(self):
+    keys, values = self._table.export()
+    np_keys = K.get_session().run(keys)
+    np_values = K.get_session().run(values)
+    return (np_keys, np_values)
+
+  def vocab_size(self):
+    return K.get_session().run(self._table.size())
+
+  def _clear_table(self):
+    keys, _ = self._table.export()
+    K.get_session().run(self._table.remove(keys))
+
+  def _insert_table_data(self, keys, values):
+    K.get_session().run(self._table.insert(keys, values))
+
+  def _to_numpy(self, data):
+    """Converts preprocessed inputs into numpy arrays."""
+    if isinstance(data, np.ndarray):
+      return data
+    session = K.get_session()
+    data = session.run(data)
+    if isinstance(data, ragged_tensor_value.RaggedTensorValue):
+      data = np.array(data.to_list())
+    return data
diff --git a/tensorflow/python/keras/layers/preprocessing/text_vectorization.py b/tensorflow/python/keras/layers/preprocessing/text_vectorization.py
index a315df0..64fa210 100644
--- a/tensorflow/python/keras/layers/preprocessing/text_vectorization.py
+++ b/tensorflow/python/keras/layers/preprocessing/text_vectorization.py
@@ -32,13 +32,12 @@
 from tensorflow.python.keras.engine.base_preprocessing_layer import Combiner
 from tensorflow.python.keras.engine.base_preprocessing_layer import CombinerPreprocessingLayer
 from tensorflow.python.keras.layers.preprocessing import categorical_encoding
+from tensorflow.python.keras.layers.preprocessing import index_lookup
 from tensorflow.python.keras.utils import layer_utils
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gen_string_ops
-from tensorflow.python.ops import lookup_ops
 from tensorflow.python.ops import string_ops
-from tensorflow.python.ops.ragged import ragged_functional_ops
 from tensorflow.python.ops.ragged import ragged_string_ops
 from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.util import compat
@@ -219,7 +218,7 @@
     # 'standardize' must be one of (None, LOWER_AND_STRIP_PUNCTUATION, callable)
     layer_utils.validate_string_arg(
         standardize,
-        allowable_strings=[LOWER_AND_STRIP_PUNCTUATION],
+        allowable_strings=(LOWER_AND_STRIP_PUNCTUATION),
         layer_name="TextVectorization",
         arg_name="standardize",
         allow_none=True,
@@ -228,7 +227,7 @@
     # 'split' must be one of (None, SPLIT_ON_WHITESPACE, callable)
     layer_utils.validate_string_arg(
         split,
-        allowable_strings=[SPLIT_ON_WHITESPACE],
+        allowable_strings=(SPLIT_ON_WHITESPACE),
         layer_name="TextVectorization",
         arg_name="split",
         allow_none=True,
@@ -237,7 +236,7 @@
     # 'output_mode' must be one of (None, INT, COUNT, BINARY, TFIDF)
     layer_utils.validate_string_arg(
         output_mode,
-        allowable_strings=[INT, COUNT, BINARY, TFIDF],
+        allowable_strings=(INT, COUNT, BINARY, TFIDF),
         layer_name="TextVectorization",
         arg_name="output_mode",
         allow_none=True)
@@ -303,24 +302,9 @@
             self._max_vocab_size, compute_idf=output_mode == TFIDF),
         **kwargs)
 
-    self._table = lookup_ops.MutableHashTable(
-        key_dtype=dtypes.string,
-        value_dtype=dtypes.int64,
-        default_value=self._oov_value,
-        name=(self._name + "_index_table"))
-
-    def fail(_):
-      raise NotImplementedError(
-          "Saving is not yet supported for TextVectorization layers.")
-    self._table._list_extra_dependencies_for_serialization = fail  # pylint: disable=protected-access
-
-    tracked_table = self._add_trackable(self._table, trainable=False)
-
-    # This is a workaround for summary() on this layer. Because the table is
-    # not mutable during training, the effective number of parameters (and so
-    # the weight shape) is 0; we add this as an attr so that the parameter
-    # counting code in the Model object doesn't throw an attribute error.
-    tracked_table.shape = tensor_shape.TensorShape((0,))
+    reserve_zero = output_mode in [None, INT]
+    self._index_lookup_layer = self._get_index_lookup_class()(
+        max_tokens=max_tokens, reserve_zero=reserve_zero, dtype=dtypes.string)
 
     # If this layer is configured for string or integer output, we do not
     # create a vectorization layer (as the output is not vectorized).
@@ -328,11 +312,11 @@
       return
 
     if max_tokens is not None and self._pad_to_max:
-      vectorize_max_tokens = max_tokens
+      max_elements = max_tokens
     else:
-      vectorize_max_tokens = None
+      max_elements = None
     self._vectorize_layer = self._get_vectorization_class()(
-        max_tokens=vectorize_max_tokens, output_mode=self._output_mode)
+        max_tokens=max_elements, output_mode=self._output_mode)
 
   # These are V1/V2 shim points. There are V1 implementations in the V1 class.
   def _get_vectorization_class(self):
@@ -342,31 +326,8 @@
     keys, values = self._table.export()
     return (keys.numpy(), values.numpy())
 
-  def _get_table_size(self):
-    return self._table.size().numpy()
-
-  def _clear_table(self):
-    if (self._output_mode in [BINARY, COUNT, TFIDF] and self._called and
-        not self._pad_to_max):
-      raise RuntimeError(("When using TextVectorization in {mode} mode, the "
-                          "vocabulary cannot be changed after the layer is "
-                          "called.").format(mode=self._output_mode))
-    keys, _ = self._table.export()
-    self._table.remove(keys)
-    self._vocab_size = 0
-
-  def _insert_table_data(self, keys, values):
-    if (self._output_mode in [BINARY, COUNT, TFIDF] and self._called and
-        not self._pad_to_max):
-      raise RuntimeError(("When using TextVectorization in {mode} mode, the "
-                          "vocabulary cannot be changed after the layer is "
-                          "called.").format(mode=self._output_mode))
-    if len(values) != len(keys):
-      raise RuntimeError("Size mismatch between values and key arrays. "
-                         "Keys had size %s, values had size %s." %
-                         (len(keys), len(values)))
-    self._table.insert(keys, values)
-    self._vocab_size += len(keys)
+  def _get_index_lookup_class(self):
+    return index_lookup.IndexLookup
 
   def _to_numpy(self, preprocessed_data):
     """Converts preprocessed inputs into numpy arrays."""
@@ -441,13 +402,7 @@
     super(TextVectorization, self).adapt(preprocessed_inputs, reset_state)
 
   def get_vocabulary(self):
-    if self._vocab_size == 0:
-      return []
-
-    keys, values = self._get_table_data()
-    # This is required because the MutableHashTable doesn't preserve insertion
-    # order, but we rely on the order of the array to assign indices.
-    return [x for _, x in sorted(zip(values, keys))]
+    return self._index_lookup_layer.get_vocabulary()
 
   def get_config(self):
     config = {
@@ -496,15 +451,33 @@
     Raises:
       ValueError: If there are too many inputs, the inputs do not match, or
         input data is missing.
+      RuntimeError: If the vocabulary cannot be set when this function is
+        called. This happens when "binary", "count", and "tfidf" modes,
+        if "pad_to_max_tokens" is False and the layer itself has already been
+        called.
     """
-    current_table_size = self._get_table_size()
-    total_vocab_size = len(vocab) + (current_table_size if append else 0)
-    if self._max_tokens is not None and total_vocab_size > self._max_vocab_size:
-      raise ValueError(
-          "Attempted to set a vocabulary larger than the maximum vocab size. "
-          "Passed vocab size is %s, max vocab size is %s. Note that the OOV "
-          "token is automatically added to the number of tokens." %
-          (total_vocab_size, self._max_vocab_size))
+    if self._output_mode != TFIDF and df_data is not None:
+      raise ValueError("df_data should only be set if output_mode is TFIDF. "
+                       "output_mode is %s." % self._output_mode)
+
+    if (self._output_mode in [BINARY, COUNT, TFIDF] and self._called and
+        not self._pad_to_max):
+      raise RuntimeError(("When using TextVectorization in {mode} mode and "
+                          "pad_to_max_tokens is False, the vocabulary cannot "
+                          "be changed after the layer is "
+                          "called.").format(mode=self._output_mode))
+
+    current_table_size = self._index_lookup_layer.vocab_size()
+    self._index_lookup_layer.set_vocabulary(vocab, append)
+
+    # When doing raw or integer output, we don't have a Vectorize layer to
+    # manage. In this case, we can return directly.
+    if self._output_mode in [None, INT]:
+      return
+
+    if not self._pad_to_max or self._max_tokens is None:
+      num_tokens = self._index_lookup_layer.vocab_size() + self._reserved_values
+      self._vectorize_layer.set_num_elements(num_tokens)
 
     # We're only _really_ appending if the table_size is nonzero. This is
     # important for some sanity checks in tfidf mode (specifically, checking if
@@ -522,35 +495,7 @@
         raise ValueError("You must pass an oov_df_value the first time "
                          "'set_vocabulary' is called when output_mode is "
                          "TFIDF.")
-    else:
-      if df_data is not None:
-        raise ValueError("df_data should only be set if output_mode is TFIDF. "
-                         "output_mode is %s." % self._output_mode)
 
-    start_index = self._reserved_values + (
-        self._get_table_size() if append else 0)
-    values = np.arange(start_index, len(vocab) + start_index, dtype=np.int64)
-
-    vocab = self._convert_to_ndarray(vocab)
-    self._assert_same_type(dtypes.string, vocab, "vocab")
-
-    values = self._convert_to_ndarray(values)
-    self._assert_same_type(dtypes.int64, values, "values")
-
-    if not append and self._vocab_size > 0:
-      self._clear_table()
-    self._insert_table_data(vocab, values)
-
-    # When doing raw or integer output, we don't have a Vectorize layer to
-    # manage. In this case, we can return directly.
-    if self._output_mode in [None, INT]:
-      return
-
-    if not self._pad_to_max or self._max_tokens is None:
-      num_tokens = total_vocab_size + self._reserved_values
-      self._vectorize_layer.set_num_elements(num_tokens)
-
-    if self._output_mode == TFIDF:
       df_data = self._convert_to_ndarray(df_data)
       if append:
         # The existing IDF data is stored in a Keras weight, so we can get it
@@ -584,9 +529,6 @@
           "dimension of the input array must be 1, got shape "
           "{}".format(input_shape))
 
-    # This handles a corner case where, if restored from weights or SavedModel,
-    # the layer might not have accurate vocab size information.
-    self._vocab_size = self._get_table_size()
     super(TextVectorization, self).build(input_shape)
 
   def _set_state_variables(self, updates):
@@ -646,13 +588,7 @@
     if self._output_mode is None:
       return inputs
 
-    # The table lookup ops don't natively support ragged tensors, so if we have
-    # a RT we need to use map_flat_values to look up every element.
-    if ragged_tensor.is_ragged(inputs):
-      indexed_data = ragged_functional_ops.map_flat_values(
-          self._table.lookup, inputs)
-    else:
-      indexed_data = self._table.lookup(inputs)
+    indexed_data = self._index_lookup_layer(inputs)
 
     if self._output_mode == INT:
       # Once we have the dense tensor, we can return it if we weren't given a
@@ -687,7 +623,7 @@
 
 
 # A note on this combiner: This contains functionality that will be extracted
-# into the Vectorization and Lookup combiner objects. At that point,
+# into the Vectorization and IndexLookup combiner objects. At that point,
 # TextVectorization can become a PreprocessingStage instead of a Layer and
 # this combiner can be retired. Until then, we leave this as is instead of
 # attempting a refactor of what will soon be deleted.
diff --git a/tensorflow/python/keras/layers/preprocessing/text_vectorization_v1.py b/tensorflow/python/keras/layers/preprocessing/text_vectorization_v1.py
index 8c5b7f1..b869bee 100644
--- a/tensorflow/python/keras/layers/preprocessing/text_vectorization_v1.py
+++ b/tensorflow/python/keras/layers/preprocessing/text_vectorization_v1.py
@@ -23,6 +23,7 @@
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras.engine import base_preprocessing_layer_v1
 from tensorflow.python.keras.layers.preprocessing import categorical_encoding_v1
+from tensorflow.python.keras.layers.preprocessing import index_lookup_v1
 from tensorflow.python.keras.layers.preprocessing import text_vectorization
 from tensorflow.python.ops.ragged import ragged_tensor_value
 from tensorflow.python.util.tf_export import keras_export
@@ -82,37 +83,8 @@
   def _get_vectorization_class(self):
     return categorical_encoding_v1.CategoricalEncoding
 
-  def _get_table_data(self):
-    keys, values = self._table.export()
-    np_keys = K.get_session().run(keys)
-    np_values = K.get_session().run(values)
-    return (np_keys, np_values)
-
-  def _get_table_size(self):
-    return K.get_session().run(self._table.size())
-
-  def _clear_table(self):
-    if (self._output_mode in [
-        text_vectorization.BINARY, text_vectorization.COUNT,
-        text_vectorization.TFIDF
-    ] and self._called and not self._pad_to_max):
-      raise RuntimeError(("When using TextVectorization in {mode} mode, the "
-                          "vocabulary cannot be changed after the layer is "
-                          "called.").format(mode=self._output_mode))
-    keys, _ = self._table.export()
-    K.get_session().run(self._table.remove(keys))
-    self._vocab_size = 0
-
-  def _insert_table_data(self, keys, values):
-    if (self._output_mode in [
-        text_vectorization.BINARY, text_vectorization.COUNT,
-        text_vectorization.TFIDF
-    ] and self._called and not self._pad_to_max):
-      raise RuntimeError(("When using TextVectorization in {mode} mode, the "
-                          "vocabulary cannot be changed after the layer is "
-                          "called.").format(mode=self._output_mode))
-    K.get_session().run(self._table.insert(keys, values))
-    self._vocab_size += len(keys)
+  def _get_index_lookup_class(self):
+    return index_lookup_v1.IndexLookup
 
   def _to_numpy(self, data):
     """Converts preprocessed inputs into numpy arrays."""
diff --git a/tensorflow/python/keras/testing_utils.py b/tensorflow/python/keras/testing_utils.py
index 41e473f..a386792 100644
--- a/tensorflow/python/keras/testing_utils.py
+++ b/tensorflow/python/keras/testing_utils.py
@@ -228,6 +228,7 @@
     layer.adapt(adapt_data)
 
   model = keras.models.Sequential()
+  model.add(keras.layers.Input(shape=input_shape[1:], dtype=input_dtype))
   model.add(layer)
   actual_output = model.predict(input_data)
   actual_output_shape = actual_output.shape
diff --git a/tensorflow/python/keras/utils/layer_utils.py b/tensorflow/python/keras/utils/layer_utils.py
index f2689d0..dcb42ab 100644
--- a/tensorflow/python/keras/utils/layer_utils.py
+++ b/tensorflow/python/keras/utils/layer_utils.py
@@ -86,7 +86,7 @@
   else:
     allowed_args = '`None`, ' if allow_none else ''
     allowed_args += 'a `Callable`, ' if allow_callables else ''
-    allowed_args += 'or one of the following values: %s' % allowable_strings
+    allowed_args += 'or one of the following values: %s' % (allowable_strings,)
     raise ValueError(("%s's %s arg received an invalid value %s. " +
                       'Allowed values are %s.') %
                      (layer_name, arg_name, input_data, allowed_args))