PFor: Support TensorLists in the while_loop converter when the condition is pfor-loop-variant.
Since they use internal stacking, they need to be accumulated differently.
PiperOrigin-RevId: 382784481
Change-Id: I1628178d61e0f7a9158b0ee57d37244e45d93297
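
For context, a minimal sketch of the pattern the title refers to, adapted from the stacked-cond test touched in the diff below and assuming TF2-style control flow: the while_loop condition depends on the pfor index `i`, so each pfor iteration runs the loop body a different number of times while mutating a TensorList.

    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import control_flow_ops
    from tensorflow.python.ops import list_ops
    from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_ops

    def loop_fn(i):
      # A TensorList whose last element differs across pfor iterations.
      handle = list_ops.tensor_list_from_tensor([20, 21, 22, 23, i], [])
      # The condition `j < i` is pfor-loop-variant: different pfor iterations
      # exit the while_loop after different numbers of steps.
      _, out_handle = control_flow_ops.while_loop(
          lambda j, _: j < i,
          lambda j, h: (j + 1, list_ops.tensor_list_set_item(h, j, i)),
          (0, handle))
      return list_ops.tensor_list_stack(out_handle, dtypes.int32)

    # Vectorizes loop_fn over 5 pfor iterations; the result has shape [5, 5].
    result = pfor_ops.pfor(loop_fn, 5)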
diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
index 465d317..db28ad1 100644
--- a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
+++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
@@ -1370,7 +1370,6 @@
self._test_loop_fn(loop_fn, 2)
- @test_util.enable_control_flow_v2
def test_tensor_list_reserve_while_loop(self):
# Here a loop invariant TensorList is captured by a while_loop, which then
# performs loop dependent operations on it, resulting in a loop variant
@@ -1378,6 +1377,8 @@
# output. This forces stacking of the variant handle captured by the
# while_loop.
# We handle this particular case by forcing vectorization of
# TensorListReserve operation.
+ v2_enabled = control_flow_v2_toggles.control_flow_v2_enabled()
+ control_flow_v2_toggles.enable_control_flow_v2()
def loop_fn(i):
handle = list_ops.tensor_list_reserve([], 2, dtypes.int32)
@@ -1387,32 +1388,8 @@
return list_ops.tensor_list_stack(out_handle, dtypes.int32)
self._test_loop_fn(loop_fn, 2)
-
- @test_util.enable_control_flow_v2
- def test_tensor_list_while_loop_stacked_cond_stacked_list(self):
-
- def loop_fn(i):
- handle = list_ops.tensor_list_from_tensor([20, 21, 22, 23, i], [])
- _, out_handle = control_flow_ops.while_loop(
- lambda j, _: j < i,
- lambda j, h: (j + 1, list_ops.tensor_list_set_item(h, j, i)),
- (0, handle))
- return list_ops.tensor_list_stack(out_handle, dtypes.int32)
-
- self._test_loop_fn(loop_fn, 5)
-
- @test_util.enable_control_flow_v2
- def test_tensor_list_while_loop_stacked_cond_unstacked_list(self):
-
- def loop_fn(i):
- handle = list_ops.tensor_list_from_tensor([20, 21, 22, 23, 24], [])
- _, out_handle = control_flow_ops.while_loop(
- lambda j, _: j < i,
- lambda j, h: (j + 1, list_ops.tensor_list_set_item(h, j, i)),
- (0, handle))
- return list_ops.tensor_list_stack(out_handle, dtypes.int32)
-
- self._test_loop_fn(loop_fn, 5)
+ if not v2_enabled:
+ control_flow_v2_toggles.disable_control_flow_v2()
def test_tensor_list_addn_already_stacked(self):
diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py
index ae3f13b..e5f3f4c 100644
--- a/tensorflow/python/ops/parallel_for/pfor.py
+++ b/tensorflow/python/ops/parallel_for/pfor.py
@@ -88,23 +88,6 @@
return handle_data.shape_and_type
-def _variant_type_id(t):
- """Returns the full_type_pb2 type of `t`, or None if it is not available."""
- if t.dtype != dtypes.variant:
- return None
- shapes_and_types = _variant_handle_data(t)
- if shapes_and_types is None or not shapes_and_types:
- # TODO(b/169968286): Identify all variant tensors (e.g. maps) and we can
- # make this an error instead of assuming TensorLists have handle data.
- return None # Presumed not a TensorList/Optional
- return shapes_and_types[0].type.type_id
-
-
-_INTERNAL_STACKING_TYPE_IDS = (
- full_type_pb2.TFT_ARRAY,
- full_type_pb2.TFT_OPTIONAL)
-
-
def _is_variant_with_internal_stacking(t):
"""Identifies variant tensors which pfor always maintains as scalars.
@@ -116,8 +99,15 @@
Returns:
True if `t` is a TensorList/Optional, False if not, None if unknown.
"""
- type_id = _variant_type_id(t)
- return type_id in _INTERNAL_STACKING_TYPE_IDS
+ if t.dtype != dtypes.variant:
+ return False
+ shapes_and_types = _variant_handle_data(t)
+ if shapes_and_types is None or not shapes_and_types:
+ # TODO(b/169968286): Identify all variant tensors (e.g. maps) and we can
+ # make this an error instead of assuming TensorLists have handle data.
+ return None # Presumed not a TensorList/Optional
+ type_id = shapes_and_types[0].type.type_id
+ return type_id in (full_type_pb2.TFT_ARRAY, full_type_pb2.TFT_OPTIONAL)
def _parse_variant_shapes_and_types(t):
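
For illustration (not part of the patch), what the docstring above means by pfor maintaining these variants as scalars: a TensorList handle is a single scalar `variant` tensor no matter what it holds, so pfor cannot prepend a `loop_len` dimension to the handle itself and instead stacks the elements inside the list. A quick sketch, assuming eager TF2:

    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import list_ops

    handle = list_ops.tensor_list_reserve([], 2, dtypes.int32)
    print(handle.dtype)  # <dtype: 'variant'>
    print(handle.shape)  # () -- a scalar handle, regardless of the list contents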
@@ -4536,60 +4526,11 @@
with ops.name_scope("while_init"):
for inp in self._pfor_input.inputs:
inputs.append(inp.t)
- variant_type_id = _variant_type_id(inp.t)
- if variant_type_id in _INTERNAL_STACKING_TYPE_IDS:
- if variant_type_id != full_type_pb2.TFT_ARRAY:
- raise NotImplementedError(
- ("While loop conversion is only supported for TensorLists. Got "
- "another variant {}, probably an optional. Please file a bug.")
- .format(inp.t))
- # For TensorLists, the input format is:
- #
- # List[user_list_len, Tensor[loop_len, ...]]
- #
- # rather than the usual
- #
- # Tensor[loop_len, ...]
- #
- # The body of the loop will take and return lists in this "internal
- # vectorization" format, so we want to keep it that way as much as
- # possible. We'll accumulate finished iterations (only relevant for
- # pfor-loop-variant while_loop conditions) in an accumulator with
- # type:
- #
- # List[user_list_len, List[loop_len, Tensor[...]]]
- #
- # This means that each while_loop iteration, we'll iterate over the
- # length of the TensorList, dividing done/remaining pfor loop indices
- # and scattering the done indices into the inner nested list of the
- # accumulator.
- element_shape = list_ops.tensor_list_element_shape(
- inp.t, dtypes.int32)[1:]
- dtype = _parse_variant_shapes_and_types(inp.t)[0].dtype
-
- def _init_loop_body(index, output_ta):
- output_ta = output_ta.write(
- index,
- list_ops.tensor_list_reserve(element_shape, loop_len, dtype))
- return index + 1, output_ta
-
- length = list_ops.tensor_list_length(inp.t)
- output_ta = tensor_array_ops.TensorArray(
- inp.t.dtype, # Variant; this is a nested TensorList
- size=length,
- dynamic_size=True,
- infer_shape=False)
- _, output_ta = control_flow_ops.while_loop(
- lambda index, _: index < length,
- _init_loop_body,
- [0, output_ta])
- else:
- output_ta = tensor_array_ops.TensorArray(
+ output_tas.append(tensor_array_ops.TensorArray(
inp.t.dtype,
size=loop_len,
dynamic_size=False,
- infer_shape=True)
- output_tas.append(output_ta)
+ infer_shape=True))
# See documentation for __call__ for the structure of init_values.
indices = (
math_ops.range(self._pfor.loop_len_vector[0])
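
For illustration (not part of the patch), a sketch of the "internal vectorization" layout the removed comments above describe, with assumed sizes loop_len = 3 and user_list_len = 2; each element of the internally stacked TensorList carries a leading loop_len axis:

    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import list_ops

    loop_len = 3       # number of pfor iterations (assumed for illustration)
    user_list_len = 2  # length of the user-visible TensorList

    # Internally stacked list: List[user_list_len, Tensor[loop_len]].
    handle = list_ops.tensor_list_reserve([loop_len], user_list_len, dtypes.int32)
    handle = list_ops.tensor_list_set_item(handle, 0, [10, 11, 12])
    handle = list_ops.tensor_list_set_item(handle, 1, [20, 21, 22])

    # Stacking yields a [user_list_len, loop_len] = [2, 3] tensor.
    stacked = list_ops.tensor_list_stack(handle, dtypes.int32)

The removed accumulator code wrapped each such element in a further TensorList of length loop_len (List[user_list_len, List[loop_len, Tensor[...]]]) so that finished pfor indices could be scattered in one element at a time.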
@@ -4617,51 +4558,21 @@
new_output_tas = []
for i, (inp, stacked) in enumerate(zip(inputs, inputs_stacked)):
pass_through = i in self._body_pass_through_indices
- if not pass_through and _variant_type_id(inp) == full_type_pb2.TFT_ARRAY:
- shape_and_type = _parse_variant_shapes_and_types(inp)[0]
- element_shape = list_ops.tensor_list_element_shape(inp, dtypes.int32)
- user_list_len = list_ops.tensor_list_length(inp)
-
- def _split_vectorized_ta_element(index, new_inp, new_out_ta):
- elem = list_ops.tensor_list_get_item(inp, index, shape_and_type.dtype,
- element_shape)
- if stacked:
- done_elem, new_elem = data_flow_ops.dynamic_partition(
- elem, conditions_int, 2)
- new_inp = list_ops.tensor_list_set_item(new_inp, index, new_elem)
- else:
- done_elem = _stack(elem, [array_ops.size(done_indices)]).t
- done_accum = new_out_ta.read(index)
- done_accum = list_ops.tensor_list_scatter(
- tensor=done_elem, indices=done_indices, input_handle=done_accum)
- new_out_ta = new_out_ta.write(index, done_accum)
- return index + 1, new_inp, new_out_ta
-
- length = list_ops.tensor_list_length(inp)
- new_inp = list_ops.tensor_list_reserve(
- tensor_shape.TensorShape([None])
- + tensor_shape.TensorShape(shape_and_type.shape)[1:],
- user_list_len, shape_and_type.dtype)
- _, new_inp, out_ta = control_flow_ops.while_loop(
- lambda index, unused_new_inp, unused_new_out_ta: index < length,
- _split_vectorized_ta_element,
- [0, new_inp, output_tas[i]])
+ # Partition the inputs.
+ if stacked:
+ done_inp, new_inp = data_flow_ops.dynamic_partition(
+ inp, conditions_int, 2)
else:
- # Partition the inputs.
- if stacked:
- done_inp, new_inp = data_flow_ops.dynamic_partition(
- inp, conditions_int, 2)
- else:
- if not pass_through:
- done_inp = _stack(inp, [array_ops.size(done_indices)]).t
- new_inp = inp
-
- out_ta = output_tas[i]
if not pass_through:
- # Note that done_indices can be empty. done_inp should also be empty
- # in that case.
- out_ta = out_ta.scatter(done_indices, done_inp)
+ done_inp = _stack(inp, [array_ops.size(done_indices)]).t
+ new_inp = inp
+
new_inputs.append(new_inp)
+ out_ta = output_tas[i]
+ if not pass_through:
+ # Note that done_indices can be empty. done_inp should also be empty
+ # in that case.
+ out_ta = out_ta.scatter(done_indices, done_inp)
new_output_tas.append(out_ta)
assert len(new_output_tas) == len(output_tas)
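
For illustration (not part of the patch), the partition step in the simplified path above, with made-up values: stacked inputs are split by the vectorized while_loop condition into pfor indices that have finished and those still running, and the finished values are then scattered into the output TensorArray at done_indices.

    from tensorflow.python.ops import data_flow_ops

    inp = [10, 20, 30, 40]         # one value per pfor iteration (stacked)
    conditions_int = [0, 1, 0, 1]  # 1 = condition still true, keep iterating
    # Partition 0 collects finished indices, partition 1 the remaining ones.
    done_inp, new_inp = data_flow_ops.dynamic_partition(inp, conditions_int, 2)
    # done_inp == [10, 30]; new_inp == [20, 40]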
@@ -4862,37 +4773,7 @@
outputs.append(init_values[i + 2])
else:
ta = output_tas[i]
- if _variant_type_id(inp) == full_type_pb2.TFT_ARRAY:
- shape_and_type = _parse_variant_shapes_and_types(inp)[0]
- length = list_ops.tensor_list_length(inp)
-
- # We have been accumulating values in a:
- #
- # List[user_list_len, List[loop_len, Tensor[...]]]
- #
- # We want to return an output in the same format as the input:
- #
- # List[user_list_len, Tensor[loop_len, ...]]
- #
- # So we need to loop over the list and stack its contents.
- def _stack_loop_body(index, output_list):
- current_value = ta.read(index)
- output_list = list_ops.tensor_list_set_item(
- output_list, index,
- list_ops.tensor_list_stack(
- current_value, shape_and_type.dtype))
- return index + 1, output_list
-
- output_list = list_ops.tensor_list_reserve(
- tensor_shape.TensorShape(shape_and_type.shape), length,
- shape_and_type.dtype)
- _, output_list = control_flow_ops.while_loop(
- lambda index, _: index < length,
- _stack_loop_body,
- [0, output_list])
- outputs.append(output_list)
- else:
- outputs.append(ta.stack())
+ outputs.append(ta.stack())
else:
outputs.append(inp)
return outputs
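
For illustration (not part of the patch), the re-packing that the removed `_stack_loop_body` above performed, with assumed sizes: each accumulator element is itself a TensorList holding one entry per pfor iteration, and tensor_list_stack turns it back into a Tensor[loop_len, ...] so the output layout matches the input layout List[user_list_len, Tensor[loop_len, ...]].

    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import list_ops

    loop_len, user_list_len = 3, 2

    # One accumulator element: a TensorList holding one scalar per pfor
    # iteration (filled directly here instead of by scatter).
    inner = list_ops.tensor_list_from_tensor([7, 8, 9], element_shape=[])

    # Re-pack into the input layout List[user_list_len, Tensor[loop_len]].
    output_list = list_ops.tensor_list_reserve([loop_len], user_list_len,
                                               dtypes.int32)
    output_list = list_ops.tensor_list_set_item(
        output_list, 0, list_ops.tensor_list_stack(inner, dtypes.int32))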