ragged support for category crossing.

PiperOrigin-RevId: 293743976
Change-Id: I5c023bee8a944a812abf7bccd360f1efc53e7275
diff --git a/tensorflow/python/keras/layers/preprocessing/categorical.py b/tensorflow/python/keras/layers/preprocessing/categorical.py
index c1aa527..8e0de00 100644
--- a/tensorflow/python/keras/layers/preprocessing/categorical.py
+++ b/tensorflow/python/keras/layers/preprocessing/categorical.py
@@ -101,11 +101,26 @@
     # TODO(tanzheny): Consider making seperator configurable.
     if depth is not None:
       raise NotImplementedError('`depth` is not supported yet.')
+    super(CategoryCrossing, self).__init__(name=name, **kwargs)
     self.num_bins = num_bins
     self.depth = depth
-    super(CategoryCrossing, self).__init__(name=name, **kwargs)
+    self._supports_ragged_inputs = True
 
   def call(self, inputs):
+    # (b/144500510) ragged.map_flat_values(sparse_cross_hashed, inputs) will
+    # cause kernel failure. Investigate and find a more efficient implementation
+    if all([ragged_tensor.is_ragged(inp) for inp in inputs]):
+      inputs = [inp.to_sparse() if ragged_tensor.is_ragged(inp) else inp
+                for inp in inputs]
+      if self.num_bins is not None:
+        output = sparse_ops.sparse_cross_hashed(
+            inputs, num_buckets=self.num_bins)
+      else:
+        output = sparse_ops.sparse_cross(inputs)
+      return ragged_tensor.RaggedTensor.from_sparse(output)
+    if any([ragged_tensor.is_ragged(inp) for inp in inputs]):
+      raise ValueError('Inputs must be either all `RaggedTensor`, or none of '
+                       'them should be `RaggedTensor`, got {}'.format(inputs))
     sparse_output = False
     if any([isinstance(inp, sparse_tensor.SparseTensor) for inp in inputs]):
       sparse_output = True
diff --git a/tensorflow/python/keras/layers/preprocessing/categorical_test.py b/tensorflow/python/keras/layers/preprocessing/categorical_test.py
index 4267b0e..a519c88 100644
--- a/tensorflow/python/keras/layers/preprocessing/categorical_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/categorical_test.py
@@ -73,6 +73,39 @@
     self.assertEqual(output.values.numpy().max(), 1)
     self.assertEqual(output.values.numpy().min(), 0)
 
+  def test_crossing_hashed_ragged_inputs(self):
+    layer = categorical.CategoryCrossing(num_bins=2)
+    inputs_0 = ragged_factory_ops.constant(
+        [['omar', 'skywalker'], ['marlo']],
+        dtype=dtypes.string)
+    inputs_1 = ragged_factory_ops.constant(
+        [['a'], ['b']],
+        dtype=dtypes.string)
+    out_data = layer([inputs_0, inputs_1])
+    expected_output = [[0, 0], [0]]
+    self.assertAllClose(expected_output, out_data)
+    inp_0_t = input_layer.Input(shape=(None,), ragged=True, dtype=dtypes.string)
+    inp_1_t = input_layer.Input(shape=(None,), ragged=True, dtype=dtypes.string)
+    out_t = layer([inp_0_t, inp_1_t])
+    model = training.Model(inputs=[inp_0_t, inp_1_t], outputs=out_t)
+    self.assertAllClose(expected_output, model.predict([inputs_0, inputs_1]))
+
+    non_hashed_layer = categorical.CategoryCrossing()
+    out_t = non_hashed_layer([inp_0_t, inp_1_t])
+    model = training.Model(inputs=[inp_0_t, inp_1_t], outputs=out_t)
+    expected_output = [[b'omar_X_a', b'skywalker_X_a'], [b'marlo_X_b']]
+    self.assertAllEqual(expected_output, model.predict([inputs_0, inputs_1]))
+
+  def test_invalid_mixed_sparse_and_ragged_input(self):
+    with self.assertRaises(ValueError):
+      layer = categorical.CategoryCrossing(num_bins=2)
+      inputs_0 = ragged_factory_ops.constant(
+          [['omar'], ['marlo']],
+          dtype=dtypes.string)
+      inputs_1 = sparse_tensor.SparseTensor(
+          indices=[[0, 1], [1, 2]], values=['d', 'e'], dense_shape=[2, 3])
+      layer([inputs_0, inputs_1])
+
   def test_crossing_with_dense_inputs(self):
     layer = categorical.CategoryCrossing()
     inputs_0 = np.asarray([[1, 2]])