Fix bug that could cause map_fn to produce incorrect results (rather than an error)
when mapping over a ragged tensor with an inappropriate fn_output_signature.  (Note: there are cases where the default value for fn_output_signature is not appropriate, so the user needs to explicitly specify the correct output signature.)

PiperOrigin-RevId: 387606546
Change-Id: Ib4ea27b9634e6ab413f211cfe809a69a90f0e2cd
diff --git a/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc b/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc
index d9993bb..c481d90 100644
--- a/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc
+++ b/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc
@@ -174,7 +174,23 @@
   auto output_values_flat =
       output_ragged->mutable_values()->flat_outer_dims<VALUE_TYPE, 2>();
   int values_index = 0;
+
+  TensorShape expected_value_shape = component_values_shape;
+  expected_value_shape.RemoveDim(0);
+
   for (int i = 0; i < ragged_components.size(); i++) {
+    // Check that the flat_values tensor shape is compatible.
+    TensorShape value_shape = ragged_components[i].values().shape();
+    value_shape.RemoveDim(0);
+    if (value_shape != expected_value_shape) {
+      return errors::InvalidArgument(
+          "All flat_values must have compatible shapes.  Shape at index 0: ",
+          expected_value_shape, ".  Shape at index ", i, ": ", value_shape,
+          ".  If you are using tf.map_fn, then you may need to specify an "
+          "explicit fn_output_signature with appropriate ragged_rank, and/or "
+          "convert output tensors to RaggedTensors.");
+    }
+
     auto component_values_flat =
         ragged_components[i].values().flat_outer_dims<VALUE_TYPE, 2>();
     int num_inner_elements = ragged_components[i].values().NumElements();
diff --git a/tensorflow/python/ops/ragged/ragged_map_fn_op_test.py b/tensorflow/python/ops/ragged/ragged_map_fn_op_test.py
index bead492..ace724a 100644
--- a/tensorflow/python/ops/ragged/ragged_map_fn_op_test.py
+++ b/tensorflow/python/ops/ragged/ragged_map_fn_op_test.py
@@ -21,9 +21,11 @@
 import numpy as np
 
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import map_fn as map_fn_lib
 from tensorflow.python.ops import math_ops as mo
 from tensorflow.python.ops import string_ops
 from tensorflow.python.ops.ragged import ragged_factory_ops
@@ -309,6 +311,27 @@
     )
     self.assertAllEqual(id_t2, [[0, 5], [0, 4]])
 
+  def testRaggedMapWithIncorrectFnOutputSignature(self):
+    x = ragged_factory_ops.constant([[1, 2, 3, 4], [1]])
+    with self.assertRaisesRegex(errors.InvalidArgumentError,
+                                'All flat_values must have compatible shapes'):
+      y = map_fn_lib.map_fn(lambda r: map_fn_lib.map_fn(lambda y: r, r), x)
+      self.evaluate(y)
+
+  def testNestedRaggedMapWithFnOutputSignature(self):
+    ragged1d = ragged_tensor.RaggedTensorSpec([None], dtypes.int32)
+    ragged2d = ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32)
+
+    x = ragged_factory_ops.constant([[1, 2, 3, 4], [1]])
+    # pylint: disable=g-long-lambda
+    y = map_fn_lib.map_fn(
+        lambda r: map_fn_lib.map_fn(
+            lambda y: r, r, fn_output_signature=ragged1d),
+        x,
+        fn_output_signature=ragged2d)
+    expected = [[[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]], [[1]]]
+    self.assertAllEqual(y, expected)
+
 
 if __name__ == '__main__':
   googletest.main()