ragged support for category crossing.
PiperOrigin-RevId: 293743976
Change-Id: I5c023bee8a944a812abf7bccd360f1efc53e7275
diff --git a/tensorflow/python/keras/layers/preprocessing/categorical.py b/tensorflow/python/keras/layers/preprocessing/categorical.py
index c1aa527..8e0de00 100644
--- a/tensorflow/python/keras/layers/preprocessing/categorical.py
+++ b/tensorflow/python/keras/layers/preprocessing/categorical.py
@@ -101,11 +101,26 @@
# TODO(tanzheny): Consider making seperator configurable.
if depth is not None:
raise NotImplementedError('`depth` is not supported yet.')
+ super(CategoryCrossing, self).__init__(name=name, **kwargs)
self.num_bins = num_bins
self.depth = depth
- super(CategoryCrossing, self).__init__(name=name, **kwargs)
+ self._supports_ragged_inputs = True
def call(self, inputs):
+ # (b/144500510) ragged.map_flat_values(sparse_cross_hashed, inputs) will
+ # cause kernel failure. Investigate and find a more efficient implementation
+ if all([ragged_tensor.is_ragged(inp) for inp in inputs]):
+ inputs = [inp.to_sparse() if ragged_tensor.is_ragged(inp) else inp
+ for inp in inputs]
+ if self.num_bins is not None:
+ output = sparse_ops.sparse_cross_hashed(
+ inputs, num_buckets=self.num_bins)
+ else:
+ output = sparse_ops.sparse_cross(inputs)
+ return ragged_tensor.RaggedTensor.from_sparse(output)
+ if any([ragged_tensor.is_ragged(inp) for inp in inputs]):
+ raise ValueError('Inputs must be either all `RaggedTensor`, or none of '
+ 'them should be `RaggedTensor`, got {}'.format(inputs))
sparse_output = False
if any([isinstance(inp, sparse_tensor.SparseTensor) for inp in inputs]):
sparse_output = True
diff --git a/tensorflow/python/keras/layers/preprocessing/categorical_test.py b/tensorflow/python/keras/layers/preprocessing/categorical_test.py
index 4267b0e..a519c88 100644
--- a/tensorflow/python/keras/layers/preprocessing/categorical_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/categorical_test.py
@@ -73,6 +73,39 @@
self.assertEqual(output.values.numpy().max(), 1)
self.assertEqual(output.values.numpy().min(), 0)
+ def test_crossing_hashed_ragged_inputs(self):
+ layer = categorical.CategoryCrossing(num_bins=2)
+ inputs_0 = ragged_factory_ops.constant(
+ [['omar', 'skywalker'], ['marlo']],
+ dtype=dtypes.string)
+ inputs_1 = ragged_factory_ops.constant(
+ [['a'], ['b']],
+ dtype=dtypes.string)
+ out_data = layer([inputs_0, inputs_1])
+ expected_output = [[0, 0], [0]]
+ self.assertAllClose(expected_output, out_data)
+ inp_0_t = input_layer.Input(shape=(None,), ragged=True, dtype=dtypes.string)
+ inp_1_t = input_layer.Input(shape=(None,), ragged=True, dtype=dtypes.string)
+ out_t = layer([inp_0_t, inp_1_t])
+ model = training.Model(inputs=[inp_0_t, inp_1_t], outputs=out_t)
+ self.assertAllClose(expected_output, model.predict([inputs_0, inputs_1]))
+
+ non_hashed_layer = categorical.CategoryCrossing()
+ out_t = non_hashed_layer([inp_0_t, inp_1_t])
+ model = training.Model(inputs=[inp_0_t, inp_1_t], outputs=out_t)
+ expected_output = [[b'omar_X_a', b'skywalker_X_a'], [b'marlo_X_b']]
+ self.assertAllEqual(expected_output, model.predict([inputs_0, inputs_1]))
+
+ def test_invalid_mixed_sparse_and_ragged_input(self):
+ with self.assertRaises(ValueError):
+ layer = categorical.CategoryCrossing(num_bins=2)
+ inputs_0 = ragged_factory_ops.constant(
+ [['omar'], ['marlo']],
+ dtype=dtypes.string)
+ inputs_1 = sparse_tensor.SparseTensor(
+ indices=[[0, 1], [1, 2]], values=['d', 'e'], dense_shape=[2, 3])
+ layer([inputs_0, inputs_1])
+
def test_crossing_with_dense_inputs(self):
layer = categorical.CategoryCrossing()
inputs_0 = np.asarray([[1, 2]])