Support a mask_value argument for preprocessing.hashing
This argument will allow specifying a value which should always map to the zero
index. For now, this will only be supported for a single tensor input as the
desired behavior when crossing multiple inputs is unclear.
PiperOrigin-RevId: 351904657
Change-Id: I8ae3fd88ef94f7b1244cd1a6da7adbc2a40dfef1
diff --git a/tensorflow/python/keras/layers/preprocessing/BUILD b/tensorflow/python/keras/layers/preprocessing/BUILD
index d13473a..5046157 100644
--- a/tensorflow/python/keras/layers/preprocessing/BUILD
+++ b/tensorflow/python/keras/layers/preprocessing/BUILD
@@ -97,8 +97,10 @@
],
srcs_version = "PY2AND3",
deps = [
+ "//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
"//tensorflow/python:sparse_ops",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:string_ops",
@@ -106,8 +108,6 @@
"//tensorflow/python:tensor_spec",
"//tensorflow/python:tensor_util",
"//tensorflow/python/keras/engine",
- "//tensorflow/python/keras/utils:tf_utils",
- "//tensorflow/python/ops/ragged:ragged_functional_ops",
"//tensorflow/python/ops/ragged:ragged_tensor",
"//tensorflow/python/util:tf_export",
"//third_party/py/numpy",
diff --git a/tensorflow/python/keras/layers/preprocessing/hashing.py b/tensorflow/python/keras/layers/preprocessing/hashing.py
index cef4183..925e1ca 100644
--- a/tensorflow/python/keras/layers/preprocessing/hashing.py
+++ b/tensorflow/python/keras/layers/preprocessing/hashing.py
@@ -28,11 +28,11 @@
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras.engine import base_preprocessing_layer
-from tensorflow.python.keras.utils import tf_utils
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_sparse_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import string_ops
-from tensorflow.python.ops.ragged import ragged_functional_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util.tf_export import keras_export
@@ -71,6 +71,19 @@
[1],
[2]])>
+ Example (FarmHash64) with a mask value:
+
+ >>> layer = tf.keras.layers.experimental.preprocessing.Hashing(num_bins=3,
+ ... mask_value='')
+ >>> inp = [['A'], ['B'], [''], ['C'], ['D']]
+ >>> layer(inp)
+ <tf.Tensor: shape=(5, 1), dtype=int64, numpy=
+ array([[1],
+ [1],
+ [0],
+ [2],
+ [2]])>
+
Example (FarmHash64) with list of inputs:
>>> layer = tf.keras.layers.experimental.preprocessing.Hashing(num_bins=3)
@@ -114,7 +127,12 @@
Reference: [SipHash with salt](https://www.131002.net/siphash/siphash.pdf)
Args:
- num_bins: Number of hash bins.
+ num_bins: Number of hash bins. Note that this includes the `mask_value` bin,
+ so the effective number of bins is `(num_bins - 1)` if `mask_value` is
+ set.
+ mask_value: A value that represents masked inputs, which are mapped to
+ index 0. Defaults to None, meaning no mask term will be added and the
+ hashing will start at index 0.
salt: A single unsigned integer or None.
If passed, the hash function used will be SipHash64, with these values
used as an additional input (known as a "salt" in cryptography).
@@ -134,12 +152,13 @@
"""
- def __init__(self, num_bins, salt=None, name=None, **kwargs):
+ def __init__(self, num_bins, mask_value=None, salt=None, name=None, **kwargs):
if num_bins is None or num_bins <= 0:
raise ValueError('`num_bins` cannot be `None` or non-positive values.')
super(Hashing, self).__init__(name=name, **kwargs)
base_preprocessing_layer.keras_kpl_gauge.get_cell('Hashing').set(True)
self.num_bins = num_bins
+ self.mask_value = mask_value
self.strong_hash = True if salt is not None else False
if salt is not None:
if isinstance(salt, (tuple, list)) and len(salt) == 2:
@@ -170,39 +189,22 @@
inputs = self._preprocess_inputs(inputs)
if isinstance(inputs, (tuple, list)):
return self._process_input_list(inputs)
- else:
- return self._process_single_input(inputs)
-
- def _process_single_input(self, inputs):
- # Converts integer inputs to string.
- if inputs.dtype.is_integer:
- if isinstance(inputs, sparse_tensor.SparseTensor):
- inputs = sparse_tensor.SparseTensor(
- indices=inputs.indices,
- values=string_ops.as_string(inputs.values),
- dense_shape=inputs.dense_shape)
- else:
- inputs = string_ops.as_string(inputs)
- str_to_hash_bucket = self._get_string_to_hash_bucket_fn()
- if tf_utils.is_ragged(inputs):
- return ragged_functional_ops.map_flat_values(
- str_to_hash_bucket, inputs, num_buckets=self.num_bins, name='hash')
elif isinstance(inputs, sparse_tensor.SparseTensor):
- sparse_values = inputs.values
- sparse_hashed_values = str_to_hash_bucket(
- sparse_values, self.num_bins, name='hash')
return sparse_tensor.SparseTensor(
indices=inputs.indices,
- values=sparse_hashed_values,
+ values=self._hash_values_to_bins(inputs.values),
dense_shape=inputs.dense_shape)
- else:
- return str_to_hash_bucket(inputs, self.num_bins, name='hash')
+ return self._hash_values_to_bins(inputs)
def _process_input_list(self, inputs):
# TODO(momernick): support ragged_cross_hashed with corrected fingerprint
# and siphash.
if any(isinstance(inp, ragged_tensor.RaggedTensor) for inp in inputs):
raise ValueError('Hashing with ragged input is not supported yet.')
+ if self.mask_value is not None:
+ raise ValueError(
+ 'Cross hashing with a mask_value is not supported yet, mask_value is '
+ '{}.'.format(self.mask_value))
sparse_inputs = [
inp for inp in inputs if isinstance(inp, sparse_tensor.SparseTensor)
]
@@ -226,6 +228,24 @@
return sparse_ops.sparse_tensor_to_dense(sparse_out)
return sparse_out
+ def _hash_values_to_bins(self, values):
+ """Converts a non-sparse tensor of values to bin indices."""
+ str_to_hash_bucket = self._get_string_to_hash_bucket_fn()
+ num_available_bins = self.num_bins
+ mask = None
+ # If mask_value is set, the zeroth bin is reserved for it.
+ if self.mask_value is not None and num_available_bins > 1:
+ num_available_bins -= 1
+ mask = math_ops.equal(values, self.mask_value)
+ # Convert all values to strings before hashing.
+ if values.dtype.is_integer:
+ values = string_ops.as_string(values)
+ values = str_to_hash_bucket(values, num_available_bins, name='hash')
+ if mask is not None:
+ values = math_ops.add(values, array_ops.ones_like(values))
+ values = array_ops.where(mask, array_ops.zeros_like(values), values)
+ return values
+
def _get_string_to_hash_bucket_fn(self):
"""Returns the string_to_hash_bucket op to use based on `hasher_key`."""
# string_to_hash_bucket_fast uses FarmHash64 as hash function.
@@ -274,6 +294,10 @@
return tensor_spec.TensorSpec(shape=output_shape, dtype=dtypes.int64)
def get_config(self):
- config = {'num_bins': self.num_bins, 'salt': self.salt}
+ config = {
+ 'num_bins': self.num_bins,
+ 'salt': self.salt,
+ 'mask_value': self.mask_value,
+ }
base_config = super(Hashing, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
diff --git a/tensorflow/python/keras/layers/preprocessing/hashing_test.py b/tensorflow/python/keras/layers/preprocessing/hashing_test.py
index 58592b8..712a78e 100644
--- a/tensorflow/python/keras/layers/preprocessing/hashing_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/hashing_test.py
@@ -51,6 +51,27 @@
# Assert equal for hashed output that should be true on all platforms.
self.assertAllClose([[0], [0], [1], [0], [0]], output)
+ def test_hash_dense_input_mask_value_farmhash(self):
+ empty_mask_layer = hashing.Hashing(num_bins=3, mask_value='')
+ omar_mask_layer = hashing.Hashing(num_bins=3, mask_value='omar')
+ inp = np.asarray([['omar'], ['stringer'], ['marlo'], ['wire'],
+ ['skywalker']])
+ empty_mask_output = empty_mask_layer(inp)
+ omar_mask_output = omar_mask_layer(inp)
+ # Outputs should be one more than test_hash_dense_input_farmhash (the zeroth
+ # bin is now reserved for masks).
+ self.assertAllClose([[1], [1], [2], [1], [1]], empty_mask_output)
+ # 'omar' should map to 0.
+ self.assertAllClose([[0], [1], [2], [1], [1]], omar_mask_output)
+
+ def test_hash_dense_multi_inputs_mask_value_farmhash(self):
+ layer = hashing.Hashing(num_bins=3, mask_value='omar')
+ inp_1 = np.asarray([['omar'], ['stringer'], ['marlo'], ['wire'],
+ ['skywalker']])
+ inp_2 = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']])
+ with self.assertRaisesRegex(ValueError, 'not supported yet'):
+ _ = layer([inp_1, inp_2])
+
def test_hash_dense_multi_inputs_farmhash(self):
layer = hashing.Hashing(num_bins=2)
inp_1 = np.asarray([['omar'], ['stringer'], ['marlo'], ['wire'],
@@ -135,6 +156,24 @@
self.assertAllClose(indices, output.indices)
self.assertAllClose([0, 0, 1, 0, 0], output.values)
+ def test_hash_sparse_input_mask_value_farmhash(self):
+ empty_mask_layer = hashing.Hashing(num_bins=3, mask_value='')
+ omar_mask_layer = hashing.Hashing(num_bins=3, mask_value='omar')
+ indices = [[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]]
+ inp = sparse_tensor.SparseTensor(
+ indices=indices,
+ values=['omar', 'stringer', 'marlo', 'wire', 'skywalker'],
+ dense_shape=[3, 2])
+ empty_mask_output = empty_mask_layer(inp)
+ omar_mask_output = omar_mask_layer(inp)
+ self.assertAllClose(indices, omar_mask_output.indices)
+ self.assertAllClose(indices, empty_mask_output.indices)
+ # Outputs should be one more than test_hash_sparse_input_farmhash (the
+ # zeroth bin is now reserved for masks).
+ self.assertAllClose([1, 1, 2, 1, 1], empty_mask_output.values)
+ # 'omar' should map to 0.
+ self.assertAllClose([0, 1, 2, 1, 1], omar_mask_output.values)
+
def test_hash_sparse_multi_inputs_farmhash(self):
layer = hashing.Hashing(num_bins=2)
indices = [[0, 0], [1, 0], [2, 0]]
@@ -217,6 +256,22 @@
model = training.Model(inputs=inp_t, outputs=out_t)
self.assertAllClose(out_data, model.predict(inp_data))
+ def test_hash_ragged_input_mask_value(self):
+ empty_mask_layer = hashing.Hashing(num_bins=3, mask_value='')
+ omar_mask_layer = hashing.Hashing(num_bins=3, mask_value='omar')
+ inp_data = ragged_factory_ops.constant(
+ [['omar', 'stringer', 'marlo', 'wire'], ['marlo', 'skywalker', 'wire']],
+ dtype=dtypes.string)
+ empty_mask_output = empty_mask_layer(inp_data)
+ omar_mask_output = omar_mask_layer(inp_data)
+ # Outputs should be one more than test_hash_ragged_string_input_farmhash
+ # (the zeroth bin is now reserved for masks).
+ expected_output = [[1, 1, 2, 1], [2, 1, 1]]
+ self.assertAllClose(expected_output, empty_mask_output)
+ # 'omar' should map to 0.
+ expected_output = [[0, 1, 2, 1], [2, 1, 1]]
+ self.assertAllClose(expected_output, omar_mask_output)
+
def test_hash_ragged_string_multi_inputs_farmhash(self):
layer = hashing.Hashing(num_bins=2)
inp_data_1 = ragged_factory_ops.constant(
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-hashing.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-hashing.pbtxt
index 2c9af8c..dbd8d4f 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-hashing.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-hashing.pbtxt
@@ -130,7 +130,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'num_bins\', \'salt\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\'], "
+ argspec: "args=[\'self\', \'num_bins\', \'mask_value\', \'salt\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "adapt"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-hashing.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-hashing.pbtxt
index 2c9af8c..dbd8d4f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-hashing.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-hashing.pbtxt
@@ -130,7 +130,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'num_bins\', \'salt\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\'], "
+ argspec: "args=[\'self\', \'num_bins\', \'mask_value\', \'salt\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "adapt"