PFor: Support TensorLists in the while_loop converter when the condition is pfor-loop-variant

Since TensorLists use pfor's internal stacking format, they need to be accumulated differently from ordinary loop-variant tensors.
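
For example, a pfor body whose while_loop condition depends on the pfor
iteration index while carrying a TensorList (a minimal sketch mirroring the
stacked-cond test in this change; assumes control flow v2 is enabled):

    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import control_flow_ops
    from tensorflow.python.ops import list_ops
    from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_ops

    def loop_fn(i):
      # Both the TensorList contents and the while_loop condition depend on
      # `i`, so each pfor iteration runs a different number of steps.
      handle = list_ops.tensor_list_from_tensor([20, 21, 22, 23, i], [])
      _, out_handle = control_flow_ops.while_loop(
          lambda j, _: j < i,
          lambda j, h: (j + 1, list_ops.tensor_list_set_item(h, j, i)),
          (0, handle))
      return list_ops.tensor_list_stack(out_handle, dtypes.int32)

    out = pfor_ops.pfor(loop_fn, 5)  # vectorized result, shape [5, 5]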

PiperOrigin-RevId: 382784481
Change-Id: I1628178d61e0f7a9158b0ee57d37244e45d93297
diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
index 465d317..db28ad1 100644
--- a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
+++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
@@ -1370,7 +1370,6 @@
 
     self._test_loop_fn(loop_fn, 2)
 
-  @test_util.enable_control_flow_v2
   def test_tensor_list_reserve_while_loop(self):
     # Here a loop invariant TensorList is captured by a while_loop, which then
     # performs loop dependent operations on it, resulting in a loop variant
@@ -1378,6 +1377,8 @@
     # while_loop.
     # We handle this particular case by forcing vectorization of
     # TensorListReserve operation.
+    v2_enabled = control_flow_v2_toggles.control_flow_v2_enabled()
+    control_flow_v2_toggles.enable_control_flow_v2()
 
     def loop_fn(i):
       handle = list_ops.tensor_list_reserve([], 2, dtypes.int32)
@@ -1387,32 +1388,8 @@
       return list_ops.tensor_list_stack(out_handle, dtypes.int32)
 
     self._test_loop_fn(loop_fn, 2)
-
-  @test_util.enable_control_flow_v2
-  def test_tensor_list_while_loop_stacked_cond_stacked_list(self):
-
-    def loop_fn(i):
-      handle = list_ops.tensor_list_from_tensor([20, 21, 22, 23, i], [])
-      _, out_handle = control_flow_ops.while_loop(
-          lambda j, _: j < i,
-          lambda j, h: (j + 1, list_ops.tensor_list_set_item(h, j, i)),
-          (0, handle))
-      return list_ops.tensor_list_stack(out_handle, dtypes.int32)
-
-    self._test_loop_fn(loop_fn, 5)
-
-  @test_util.enable_control_flow_v2
-  def test_tensor_list_while_loop_stacked_cond_unstacked_list(self):
-
-    def loop_fn(i):
-      handle = list_ops.tensor_list_from_tensor([20, 21, 22, 23, 24], [])
-      _, out_handle = control_flow_ops.while_loop(
-          lambda j, _: j < i,
-          lambda j, h: (j + 1, list_ops.tensor_list_set_item(h, j, i)),
-          (0, handle))
-      return list_ops.tensor_list_stack(out_handle, dtypes.int32)
-
-    self._test_loop_fn(loop_fn, 5)
+    if not v2_enabled:
+      control_flow_v2_toggles.disable_control_flow_v2()
 
   def test_tensor_list_addn_already_stacked(self):
 
diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py
index ae3f13b..e5f3f4c 100644
--- a/tensorflow/python/ops/parallel_for/pfor.py
+++ b/tensorflow/python/ops/parallel_for/pfor.py
@@ -88,23 +88,6 @@
   return handle_data.shape_and_type
 
 
-def _variant_type_id(t):
-  """Returns the full_type_pb2 type of `t`, or None if it is not available."""
-  if t.dtype != dtypes.variant:
-    return None
-  shapes_and_types = _variant_handle_data(t)
-  if shapes_and_types is None or not shapes_and_types:
-    # TODO(b/169968286): Identify all variant tensors (e.g. maps) and we can
-    # make this an error instead of assuming TensorLists have handle data.
-    return None  # Presumed not a TensorList/Optional
-  return shapes_and_types[0].type.type_id
-
-
-_INTERNAL_STACKING_TYPE_IDS = (
-    full_type_pb2.TFT_ARRAY,
-    full_type_pb2.TFT_OPTIONAL)
-
-
 def _is_variant_with_internal_stacking(t):
   """Identifies variant tensors which pfor always maintains as scalars.
 
@@ -116,8 +99,15 @@
   Returns:
     True if `t` is a TensorList/Optional, False if not, None if unknown.
   """
-  type_id = _variant_type_id(t)
-  return type_id in _INTERNAL_STACKING_TYPE_IDS
+  if t.dtype != dtypes.variant:
+    return False
+  shapes_and_types = _variant_handle_data(t)
+  if shapes_and_types is None or not shapes_and_types:
+    # TODO(b/169968286): Identify all variant tensors (e.g. maps) so that we can
+    # make this an error instead of assuming TensorLists have handle data.
+    return None  # Presumed not a TensorList/Optional
+  type_id = shapes_and_types[0].type.type_id
+  return type_id in (full_type_pb2.TFT_ARRAY, full_type_pb2.TFT_OPTIONAL)
 
 
 def _parse_variant_shapes_and_types(t):
@@ -4536,60 +4526,11 @@
     with ops.name_scope("while_init"):
       for inp in self._pfor_input.inputs:
         inputs.append(inp.t)
-        variant_type_id = _variant_type_id(inp.t)
-        if variant_type_id in _INTERNAL_STACKING_TYPE_IDS:
-          if variant_type_id != full_type_pb2.TFT_ARRAY:
-            raise NotImplementedError(
-                ("While loop conversion is only supported for TensorLists. Got "
-                 "another variant {}, probably an optional. Please file a bug.")
-                .format(inp.t))
-          # For TensorLists, the input format is:
-          #
-          #   List[user_list_len, Tensor[loop_len, ...]]
-          #
-          # rather than the usual
-          #
-          #   Tensor[loop_len, ...]
-          #
-          # The body of the loop will take and return lists in this "internal
-          # vectorization" format, so we want to keep it that way as much as
-          # possible. We'll accumulate finished iterations (only relevant for
-          # pfor-loop-variant while_loop conditions) in an accumulator with
-          # type:
-          #
-          #   List[user_list_len, List[loop_len, Tensor[...]]]
-          #
-          # This means that each while_loop iteration, we'll iterate over the
-          # length of the TensorList, dividing done/remaining pfor loop indices
-          # and scattering the done indices into the inner nested list of the
-          # accumulator.
-          element_shape = list_ops.tensor_list_element_shape(
-              inp.t, dtypes.int32)[1:]
-          dtype = _parse_variant_shapes_and_types(inp.t)[0].dtype
-
-          def _init_loop_body(index, output_ta):
-            output_ta = output_ta.write(
-                index,
-                list_ops.tensor_list_reserve(element_shape, loop_len, dtype))
-            return index + 1, output_ta
-
-          length = list_ops.tensor_list_length(inp.t)
-          output_ta = tensor_array_ops.TensorArray(
-            inp.t.dtype,  # Variant; this is a nested TensorList
-            size=length,
-            dynamic_size=True,
-            infer_shape=False)
-          _, output_ta = control_flow_ops.while_loop(
-              lambda index, _: index < length,
-              _init_loop_body,
-              [0, output_ta])
-        else:
-          output_ta = tensor_array_ops.TensorArray(
+        output_tas.append(tensor_array_ops.TensorArray(
             inp.t.dtype,
             size=loop_len,
             dynamic_size=False,
-            infer_shape=True)
-        output_tas.append(output_ta)
+            infer_shape=True))
     # See documentation for __call__ for the structure of init_values.
     indices = (
         math_ops.range(self._pfor.loop_len_vector[0])
@@ -4617,51 +4558,21 @@
     new_output_tas = []
     for i, (inp, stacked) in enumerate(zip(inputs, inputs_stacked)):
       pass_through = i in self._body_pass_through_indices
-      if not pass_through and  _variant_type_id(inp) == full_type_pb2.TFT_ARRAY:
-        shape_and_type = _parse_variant_shapes_and_types(inp)[0]
-        element_shape = list_ops.tensor_list_element_shape(inp, dtypes.int32)
-        user_list_len = list_ops.tensor_list_length(inp)
-
-        def _split_vectorized_ta_element(index, new_inp, new_out_ta):
-          elem = list_ops.tensor_list_get_item(inp, index, shape_and_type.dtype,
-                                               element_shape)
-          if stacked:
-            done_elem, new_elem = data_flow_ops.dynamic_partition(
-                elem, conditions_int, 2)
-            new_inp = list_ops.tensor_list_set_item(new_inp, index, new_elem)
-          else:
-            done_elem = _stack(elem, [array_ops.size(done_indices)]).t
-          done_accum = new_out_ta.read(index)
-          done_accum = list_ops.tensor_list_scatter(
-              tensor=done_elem, indices=done_indices, input_handle=done_accum)
-          new_out_ta = new_out_ta.write(index, done_accum)
-          return index + 1, new_inp, new_out_ta
-
-        length = list_ops.tensor_list_length(inp)
-        new_inp = list_ops.tensor_list_reserve(
-            tensor_shape.TensorShape([None])
-            + tensor_shape.TensorShape(shape_and_type.shape)[1:],
-            user_list_len, shape_and_type.dtype)
-        _, new_inp, out_ta = control_flow_ops.while_loop(
-            lambda index, unused_new_inp, unused_new_out_ta: index < length,
-            _split_vectorized_ta_element,
-            [0, new_inp, output_tas[i]])
+      # Partition the inputs.
+      if stacked:
+        done_inp, new_inp = data_flow_ops.dynamic_partition(
+            inp, conditions_int, 2)
       else:
-        # Partition the inputs.
-        if stacked:
-          done_inp, new_inp = data_flow_ops.dynamic_partition(
-              inp, conditions_int, 2)
-        else:
-          if not pass_through:
-            done_inp = _stack(inp, [array_ops.size(done_indices)]).t
-          new_inp = inp
-
-        out_ta = output_tas[i]
         if not pass_through:
-          # Note that done_indices can be empty. done_inp should also be empty
-          # in that case.
-          out_ta = out_ta.scatter(done_indices, done_inp)
+          done_inp = _stack(inp, [array_ops.size(done_indices)]).t
+        new_inp = inp
+
       new_inputs.append(new_inp)
+      out_ta = output_tas[i]
+      if not pass_through:
+        # Note that done_indices can be empty. done_inp should also be empty
+        # in that case.
+        out_ta = out_ta.scatter(done_indices, done_inp)
       new_output_tas.append(out_ta)
 
     assert len(new_output_tas) == len(output_tas)
@@ -4862,37 +4773,7 @@
               outputs.append(init_values[i + 2])
             else:
               ta = output_tas[i]
-              if _variant_type_id(inp) == full_type_pb2.TFT_ARRAY:
-                shape_and_type = _parse_variant_shapes_and_types(inp)[0]
-                length = list_ops.tensor_list_length(inp)
-
-                # We have been accumulating values in a:
-                #
-                #   List[user_list_len, List[loop_len, Tensor[...]]]
-                #
-                # We want to return an output in the same format as the input:
-                #
-                #   List[user_list_len, Tensor[loop_len, ...]]
-                #
-                # So we need to loop over the list and stack its contents.
-                def _stack_loop_body(index, output_list):
-                  current_value = ta.read(index)
-                  output_list = list_ops.tensor_list_set_item(
-                      output_list, index,
-                      list_ops.tensor_list_stack(
-                          current_value, shape_and_type.dtype))
-                  return index + 1, output_list
-
-                output_list = list_ops.tensor_list_reserve(
-                    tensor_shape.TensorShape(shape_and_type.shape), length,
-                    shape_and_type.dtype)
-                _, output_list = control_flow_ops.while_loop(
-                    lambda index, _: index < length,
-                    _stack_loop_body,
-                    [0, output_list])
-                outputs.append(output_list)
-              else:
-                outputs.append(ta.stack())
+              outputs.append(ta.stack())
           else:
             outputs.append(inp)
         return outputs