Support a mask_value argument for preprocessing.hashing

This argument specifies a value that always maps to index 0. For now it is
supported only for a single tensor input, because the desired behavior when
crossing multiple inputs is unclear.

PiperOrigin-RevId: 351904657
Change-Id: I8ae3fd88ef94f7b1244cd1a6da7adbc2a40dfef1
diff --git a/tensorflow/python/keras/layers/preprocessing/BUILD b/tensorflow/python/keras/layers/preprocessing/BUILD
index d13473a..5046157 100644
--- a/tensorflow/python/keras/layers/preprocessing/BUILD
+++ b/tensorflow/python/keras/layers/preprocessing/BUILD
@@ -97,8 +97,10 @@
     ],
     srcs_version = "PY2AND3",
     deps = [
+        "//tensorflow/python:array_ops",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:framework_ops",
+        "//tensorflow/python:math_ops",
         "//tensorflow/python:sparse_ops",
         "//tensorflow/python:sparse_tensor",
         "//tensorflow/python:string_ops",
@@ -106,8 +108,6 @@
         "//tensorflow/python:tensor_spec",
         "//tensorflow/python:tensor_util",
         "//tensorflow/python/keras/engine",
-        "//tensorflow/python/keras/utils:tf_utils",
-        "//tensorflow/python/ops/ragged:ragged_functional_ops",
         "//tensorflow/python/ops/ragged:ragged_tensor",
         "//tensorflow/python/util:tf_export",
         "//third_party/py/numpy",
diff --git a/tensorflow/python/keras/layers/preprocessing/hashing.py b/tensorflow/python/keras/layers/preprocessing/hashing.py
index cef4183..925e1ca 100644
--- a/tensorflow/python/keras/layers/preprocessing/hashing.py
+++ b/tensorflow/python/keras/layers/preprocessing/hashing.py
@@ -28,11 +28,11 @@
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.keras.engine import base_preprocessing_layer
-from tensorflow.python.keras.utils import tf_utils
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_sparse_ops
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import sparse_ops
 from tensorflow.python.ops import string_ops
-from tensorflow.python.ops.ragged import ragged_functional_ops
 from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.util.tf_export import keras_export
 
@@ -71,6 +71,19 @@
            [1],
            [2]])>
 
+  Example (FarmHash64) with a mask value:
+
+  >>> layer = tf.keras.layers.experimental.preprocessing.Hashing(num_bins=3,
+  ...    mask_value='')
+  >>> inp = [['A'], ['B'], [''], ['C'], ['D']]
+  >>> layer(inp)
+  <tf.Tensor: shape=(5, 1), dtype=int64, numpy=
+    array([[1],
+           [1],
+           [0],
+           [2],
+           [2]])>
+
 
   Example (FarmHash64) with list of inputs:
   >>> layer = tf.keras.layers.experimental.preprocessing.Hashing(num_bins=3)
@@ -114,7 +127,12 @@
   Reference: [SipHash with salt](https://www.131002.net/siphash/siphash.pdf)
 
   Args:
-    num_bins: Number of hash bins.
+    num_bins: Number of hash bins. If `mask_value` is set and `num_bins > 1`,
+      bin 0 is reserved for it and values hash into the other `num_bins - 1`
+      bins; with `num_bins == 1` no bin is reserved for the mask.
+    mask_value: A value that represents masked inputs, which are mapped to
+      index 0. Defaults to None (no mask bin; hashing starts at index 0). Not
+      supported when crossing a list of inputs.
     salt: A single unsigned integer or None.
       If passed, the hash function used will be SipHash64, with these values
       used as an additional input (known as a "salt" in cryptography).
@@ -134,12 +152,13 @@
 
   """
 
-  def __init__(self, num_bins, salt=None, name=None, **kwargs):
+  def __init__(self, num_bins, mask_value=None, salt=None, name=None, **kwargs):
     if num_bins is None or num_bins <= 0:
       raise ValueError('`num_bins` cannot be `None` or non-positive values.')
     super(Hashing, self).__init__(name=name, **kwargs)
     base_preprocessing_layer.keras_kpl_gauge.get_cell('Hashing').set(True)
     self.num_bins = num_bins
+    self.mask_value = mask_value
     self.strong_hash = True if salt is not None else False
     if salt is not None:
       if isinstance(salt, (tuple, list)) and len(salt) == 2:
@@ -170,39 +189,22 @@
     inputs = self._preprocess_inputs(inputs)
     if isinstance(inputs, (tuple, list)):
       return self._process_input_list(inputs)
-    else:
-      return self._process_single_input(inputs)
-
-  def _process_single_input(self, inputs):
-    # Converts integer inputs to string.
-    if inputs.dtype.is_integer:
-      if isinstance(inputs, sparse_tensor.SparseTensor):
-        inputs = sparse_tensor.SparseTensor(
-            indices=inputs.indices,
-            values=string_ops.as_string(inputs.values),
-            dense_shape=inputs.dense_shape)
-      else:
-        inputs = string_ops.as_string(inputs)
-    str_to_hash_bucket = self._get_string_to_hash_bucket_fn()
-    if tf_utils.is_ragged(inputs):
-      return ragged_functional_ops.map_flat_values(
-          str_to_hash_bucket, inputs, num_buckets=self.num_bins, name='hash')
     elif isinstance(inputs, sparse_tensor.SparseTensor):
-      sparse_values = inputs.values
-      sparse_hashed_values = str_to_hash_bucket(
-          sparse_values, self.num_bins, name='hash')
       return sparse_tensor.SparseTensor(
           indices=inputs.indices,
-          values=sparse_hashed_values,
+          values=self._hash_values_to_bins(inputs.values),
           dense_shape=inputs.dense_shape)
-    else:
-      return str_to_hash_bucket(inputs, self.num_bins, name='hash')
+    return self._hash_values_to_bins(inputs)
 
   def _process_input_list(self, inputs):
     # TODO(momernick): support ragged_cross_hashed with corrected fingerprint
     # and siphash.
     if any(isinstance(inp, ragged_tensor.RaggedTensor) for inp in inputs):
       raise ValueError('Hashing with ragged input is not supported yet.')
+    if self.mask_value is not None:
+      raise ValueError(
+          'Cross hashing with a mask_value is not supported yet, mask_value is '
+          '{}.'.format(self.mask_value))
     sparse_inputs = [
         inp for inp in inputs if isinstance(inp, sparse_tensor.SparseTensor)
     ]
@@ -226,6 +228,24 @@
       return sparse_ops.sparse_tensor_to_dense(sparse_out)
     return sparse_out
 
+  def _hash_values_to_bins(self, values):
+    """Converts a non-sparse tensor of values to bin indices."""
+    str_to_hash_bucket = self._get_string_to_hash_bucket_fn()
+    num_available_bins = self.num_bins
+    mask = None
+    # If mask_value is set and num_bins > 1, the zeroth bin is reserved for it.
+    if self.mask_value is not None and num_available_bins > 1:
+      num_available_bins -= 1
+      mask = math_ops.equal(values, self.mask_value)
+    # Convert all values to strings before hashing.
+    if values.dtype.is_integer:
+      values = string_ops.as_string(values)
+    values = str_to_hash_bucket(values, num_available_bins, name='hash')
+    if mask is not None:
+      values = math_ops.add(values, array_ops.ones_like(values))
+      values = array_ops.where(mask, array_ops.zeros_like(values), values)
+    return values
+
   def _get_string_to_hash_bucket_fn(self):
     """Returns the string_to_hash_bucket op to use based on `hasher_key`."""
     # string_to_hash_bucket_fast uses FarmHash64 as hash function.
@@ -274,6 +294,10 @@
     return tensor_spec.TensorSpec(shape=output_shape, dtype=dtypes.int64)
 
   def get_config(self):
-    config = {'num_bins': self.num_bins, 'salt': self.salt}
+    config = {
+        'num_bins': self.num_bins,
+        'salt': self.salt,
+        'mask_value': self.mask_value,
+    }
     base_config = super(Hashing, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
diff --git a/tensorflow/python/keras/layers/preprocessing/hashing_test.py b/tensorflow/python/keras/layers/preprocessing/hashing_test.py
index 58592b8..712a78e 100644
--- a/tensorflow/python/keras/layers/preprocessing/hashing_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/hashing_test.py
@@ -51,6 +51,27 @@
     # Assert equal for hashed output that should be true on all platforms.
     self.assertAllClose([[0], [0], [1], [0], [0]], output)
 
+  def test_hash_dense_input_mask_value_farmhash(self):
+    empty_mask_layer = hashing.Hashing(num_bins=3, mask_value='')
+    omar_mask_layer = hashing.Hashing(num_bins=3, mask_value='omar')
+    inp = np.asarray([['omar'], ['stringer'], ['marlo'], ['wire'],
+                      ['skywalker']])
+    empty_mask_output = empty_mask_layer(inp)
+    omar_mask_output = omar_mask_layer(inp)
+    # Outputs should be one more than test_hash_dense_input_farmhash (the zeroth
+    # bin is now reserved for masks).
+    self.assertAllClose([[1], [1], [2], [1], [1]], empty_mask_output)
+    # 'omar' should map to 0.
+    self.assertAllClose([[0], [1], [2], [1], [1]], omar_mask_output)
+
+  def test_hash_dense_multi_inputs_mask_value_farmhash(self):
+    layer = hashing.Hashing(num_bins=3, mask_value='omar')
+    inp_1 = np.asarray([['omar'], ['stringer'], ['marlo'], ['wire'],
+                        ['skywalker']])
+    inp_2 = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']])
+    with self.assertRaisesRegex(ValueError, 'not supported yet'):
+      _ = layer([inp_1, inp_2])
+
   def test_hash_dense_multi_inputs_farmhash(self):
     layer = hashing.Hashing(num_bins=2)
     inp_1 = np.asarray([['omar'], ['stringer'], ['marlo'], ['wire'],
@@ -135,6 +156,24 @@
     self.assertAllClose(indices, output.indices)
     self.assertAllClose([0, 0, 1, 0, 0], output.values)
 
+  def test_hash_sparse_input_mask_value_farmhash(self):
+    empty_mask_layer = hashing.Hashing(num_bins=3, mask_value='')
+    omar_mask_layer = hashing.Hashing(num_bins=3, mask_value='omar')
+    indices = [[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]]
+    inp = sparse_tensor.SparseTensor(
+        indices=indices,
+        values=['omar', 'stringer', 'marlo', 'wire', 'skywalker'],
+        dense_shape=[3, 2])
+    empty_mask_output = empty_mask_layer(inp)
+    omar_mask_output = omar_mask_layer(inp)
+    self.assertAllClose(indices, omar_mask_output.indices)
+    self.assertAllClose(indices, empty_mask_output.indices)
+    # Outputs should be one more than test_hash_sparse_input_farmhash (the
+    # zeroth bin is now reserved for masks).
+    self.assertAllClose([1, 1, 2, 1, 1], empty_mask_output.values)
+    # 'omar' should map to 0.
+    self.assertAllClose([0, 1, 2, 1, 1], omar_mask_output.values)
+
   def test_hash_sparse_multi_inputs_farmhash(self):
     layer = hashing.Hashing(num_bins=2)
     indices = [[0, 0], [1, 0], [2, 0]]
@@ -217,6 +256,22 @@
     model = training.Model(inputs=inp_t, outputs=out_t)
     self.assertAllClose(out_data, model.predict(inp_data))
 
+  def test_hash_ragged_input_mask_value(self):
+    empty_mask_layer = hashing.Hashing(num_bins=3, mask_value='')
+    omar_mask_layer = hashing.Hashing(num_bins=3, mask_value='omar')
+    inp_data = ragged_factory_ops.constant(
+        [['omar', 'stringer', 'marlo', 'wire'], ['marlo', 'skywalker', 'wire']],
+        dtype=dtypes.string)
+    empty_mask_output = empty_mask_layer(inp_data)
+    omar_mask_output = omar_mask_layer(inp_data)
+    # Outputs should be one more than test_hash_ragged_string_input_farmhash
+    # (the zeroth bin is now reserved for masks).
+    expected_output = [[1, 1, 2, 1], [2, 1, 1]]
+    self.assertAllClose(expected_output, empty_mask_output)
+    # 'omar' should map to 0.
+    expected_output = [[0, 1, 2, 1], [2, 1, 1]]
+    self.assertAllClose(expected_output, omar_mask_output)
+
   def test_hash_ragged_string_multi_inputs_farmhash(self):
     layer = hashing.Hashing(num_bins=2)
     inp_data_1 = ragged_factory_ops.constant(
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-hashing.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-hashing.pbtxt
index 2c9af8c..dbd8d4f 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-hashing.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-hashing.pbtxt
@@ -130,7 +130,7 @@
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'num_bins\', \'salt\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\'], "
+    argspec: "args=[\'self\', \'num_bins\', \'mask_value\', \'salt\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "adapt"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-hashing.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-hashing.pbtxt
index 2c9af8c..dbd8d4f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-hashing.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-hashing.pbtxt
@@ -130,7 +130,7 @@
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'num_bins\', \'salt\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\'], "
+    argspec: "args=[\'self\', \'num_bins\', \'mask_value\', \'salt\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "adapt"