Add pfor converters and basic type information to optionals
As with TensorLists, pfor vectorizes the contents of an optional rather than the variant-dtype tensor itself.
Fixes #44502.
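
For example, this allows computing jacobians through tf.cond, whose
cond_v2 gradient passes intermediate values between the forward and
backward functions via optionals. A minimal sketch (mirroring the new
test in backprop_test.py; uses only public TF APIs):

  import tensorflow as tf

  @tf.function
  def f(x):
    return tf.cond(x > 0., lambda: x**3., lambda: x**2.)

  with tf.GradientTape(persistent=True) as tape:
    x = tf.constant(1.)
    tape.watch(x)
    grad = tape.gradient(f(x), x)  # 3 * x**2 = 3. at x = 1.

  # Vectorizing this previously failed because pfor had no converters
  # for the optional-producing ops in the cond gradient.
  print(tape.jacobian(grad, x, experimental_use_pfor=True))  # 6 * x = 6.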
PiperOrigin-RevId: 342694491
Change-Id: I4f0a7cab5c59cdae85e1c34d6a479b0dc562e560
diff --git a/tensorflow/core/framework/types.proto b/tensorflow/core/framework/types.proto
index 01b5985..e5f3303 100644
--- a/tensorflow/core/framework/types.proto
+++ b/tensorflow/core/framework/types.proto
@@ -84,4 +84,6 @@
ST_INVALID = 0;
// "tensorflow::TensorList" in the variant type registry.
ST_TENSOR_LIST = 1;
+ // "tensorflow::data::Optional" in the variant type registry.
+ ST_OPTIONAL = 2;
}
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index c4bd397..e00ebf7 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -853,7 +853,20 @@
.Input("components: Toutput_types")
.Output("optional: variant")
.Attr("Toutput_types: list(type) >= 1")
- .SetShapeFn(shape_inference::ScalarShape);
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ std::vector<DataType> dtypes;
+ TF_RETURN_IF_ERROR(c->GetAttr("Toutput_types", &dtypes));
+ c->set_output(0, c->Scalar());
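+      // Record each component's shape and dtype on the output handle so that
+      // consumers of the variant (e.g. pfor's converters) can recover them.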
+ std::vector<shape_inference::ShapeAndType> shapes_and_types;
+ shapes_and_types.reserve(c->num_inputs());
+ for (int i = 0; i < c->num_inputs(); ++i) {
+ shapes_and_types.emplace_back(c->input(i), dtypes[i], ST_OPTIONAL);
+ }
+ c->set_output_handle_shapes_and_types(0, shapes_and_types);
+ return Status::OK();
+ });
REGISTER_OP("OptionalNone")
.Output("optional: variant")
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index edb02b0..0063b7f 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -1844,6 +1844,24 @@
self.assertAllClose(compute_jacobian(use_pfor=True),
compute_jacobian(use_pfor=False))
+ def test_cond_func_grad_jacobian(self):
+
+ @def_function.function
+ def f(x):
+ y = control_flow_ops.cond(x > 0., lambda: x**3., lambda: x**2.)
+ return y
+
+ with backprop.GradientTape(persistent=True) as tape:
+ x = constant_op.constant(1.)
+ tape.watch(x)
+ y = f(x)
+ grad = tape.gradient(y, x)
+ self.assertAllClose(3., grad)
+ jacobian = tape.jacobian(grad, x, experimental_use_pfor=False)
+ self.assertAllClose(6., jacobian)
+ jacobian_pfor = tape.jacobian(grad, x, experimental_use_pfor=True)
+ self.assertAllClose(6., jacobian_pfor)
+
@test_util.run_all_in_graph_and_eager_modes
class BatchJacobianTest(test.TestCase, parameterized.TestCase):
diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
index f641687..b765084 100644
--- a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
+++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
@@ -45,6 +45,7 @@
from tensorflow.python.ops import control_flow_v2_toggles
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import functional_ops
+from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_list_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradient_checker_v2
@@ -1159,6 +1160,22 @@
self._test_loop_fn(loop_fn, 2)
+class OptionalTest(PForTestCase):
+
+ def test_optional_from_value(self):
+
+ def loop_fn(i):
+ o = gen_dataset_ops.optional_from_value(
+ [i, i + 1, constant_op.constant(3)])
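+      # OptionalNone takes no inputs, so it is loop-invariant under pfor.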
+ gen_dataset_ops.optional_none()
+ return gen_dataset_ops.optional_get_value(
+ o, [dtypes.int32, dtypes.int32, dtypes.int32],
+ [[], [], []])
+
+ self._test_loop_fn(loop_fn, 2)
+
+
class StackTest(PForTestCase):
@test_util.run_v1_only("b/122612051")
diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py
index 2b02d5e..c9431fa 100644
--- a/tensorflow/python/ops/parallel_for/pfor.py
+++ b/tensorflow/python/ops/parallel_for/pfor.py
@@ -46,6 +46,7 @@
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_image_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import gen_list_ops
@@ -83,22 +84,42 @@
handle_data = resource_variable_ops.get_eager_safe_handle_data(t)
if not handle_data.is_set:
return None
- if len(handle_data.shape_and_type) != 1:
- raise ValueError("Expected handle data of length 1, got {!r} of length {}"
- .format(handle_data, len(handle_data.shape_and_type)))
- return handle_data.shape_and_type[0]
+ return handle_data.shape_and_type
-def _is_tensor_list(t):
- """True if `t` is a TensorList, False if it isn't, None if unknown."""
+def _is_variant_with_internal_stacking(t):
+ """Identifies variant tensors which pfor always maintains as scalars.
+
+  For these, the pfor tensor is recorded as "stacked" if the contents of the
+  variant tensor (e.g. the elements of a TensorList) are all stacked.
+
+  Args:
+    t: A tensor to identify.
+
+  Returns:
+    True if `t` is a TensorList/Optional, False if not, None if unknown.
+  """
if t.dtype != dtypes.variant:
return False
- shape_and_type = _variant_handle_data(t)
- if shape_and_type is None:
- # TODO(b/169968286): Identify all variant tensors (e.g. optionals) and we
- # can make this an error instead of assuming TensorLists have handle data.
- return None # Presumed not a TensorList
- return shape_and_type.specialized_type == types_pb2.ST_TENSOR_LIST
+ shapes_and_types = _variant_handle_data(t)
+ if shapes_and_types is None or not shapes_and_types:
+    # TODO(b/169968286): Identify all variant tensors (e.g. maps) so this can
+    # become an error instead of assuming TensorLists have handle data.
+ return None # Presumed not a TensorList/Optional
+ return (shapes_and_types[0].specialized_type == types_pb2.ST_TENSOR_LIST or
+ shapes_and_types[0].specialized_type == types_pb2.ST_OPTIONAL)
+
+
+def _parse_variant_shapes_and_types(t):
+ """Extracts shape and dtype information from a variant tensor `t`."""
+ shapes_and_types = _variant_handle_data(t)
+ if shapes_and_types is None or not shapes_and_types:
+ raise ValueError("Required handle data not set for {!r}".format(t))
+  if shapes_and_types[0].specialized_type == types_pb2.ST_INVALID:
+    raise ValueError(
+        "Attempted to stack a variant-dtype tensor with no type set ({!r})"
+        .format(t))
+  return shapes_and_types
def _stack(t, length):
@@ -109,23 +133,19 @@
# suitable since operations on stacked handles may expect a vectorized version
# of the variant.
if t.dtype == dtypes.variant:
- shape_and_type = _variant_handle_data(t)
- if shape_and_type is None:
- raise ValueError("Required handle data not set for {!r}".format(t))
- if shape_and_type.specialized_type == types_pb2.ST_TENSOR_LIST:
+ shapes_and_types = _parse_variant_shapes_and_types(t)
+ if shapes_and_types[0].specialized_type == types_pb2.ST_TENSOR_LIST:
+ if len(shapes_and_types) != 1:
+ raise ValueError(
+ "Expected handle data of length 1, got {!r} of length {}"
+ .format(shapes_and_types, len(shapes_and_types)))
return wrap(
- _stack_tensor_list(t, shape_and_type.dtype, length),
+ _stack_tensor_list(t, shapes_and_types[0].dtype, length),
True)
else:
- if shape_and_type.specialized_type != types_pb2.ST_INVALID:
- raise ValueError(
- ("Attempted to stack an unhandled variant-dtype tensor of "
- "type {!r} ({!r})").format(
- shape_and_type.specialized_type, t))
- else:
- raise ValueError(
- "Attempted to stack a variant-dtype tensor with no type set ({!r})"
- .format(t))
+ raise ValueError(
+ ("Attempted to stack an unhandled variant-dtype tensor of "
+ "type {!r} ({!r})").format(shapes_and_types[0].specialized_type, t))
ones = array_ops.ones_like(array_ops.shape(t))
ones = array_ops.reshape(ones, [-1])
length = array_ops.reshape(length, [-1])
@@ -1629,7 +1649,7 @@
else:
batch_dim = tensor_shape.TensorShape(loop_len)
output_shape = batch_dim.concatenate(output_shape)
- if _is_tensor_list(new_output.t):
+ if _is_variant_with_internal_stacking(new_output.t):
new_output.t.set_shape([])
else:
new_output.t.set_shape(output_shape)
@@ -3602,7 +3622,7 @@
def _tile_variant_with_length(t, length):
"""stacks `t` `length` times."""
- if _is_tensor_list(t):
+ if _is_variant_with_internal_stacking(t):
# The content of TensorLists is vectorized, not the variant itself.
return t
original_tensor = t
@@ -3622,16 +3642,45 @@
def _untile_variant(t):
- if _is_tensor_list(t):
+ if _is_variant_with_internal_stacking(t):
# The content of TensorLists is vectorized, not the variant itself.
if not t.shape.is_compatible_with([]):
raise AssertionError(
- "Unexpectedly saw a TensorList with non-scalar shape: {!r}"
- .format(t))
+ ("Unexpectedly saw a vectorized variant (e.g. TensorList) with "
+ "non-scalar shape: {!r}").format(t))
return t
return array_ops.gather(t, 0)
+@RegisterPFor("OptionalFromValue")
+def _convert_optional_from_value(pfor_input):
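+  # Stack all inputs, then store the stacked tensors in a single Optional;
+  # the variant itself stays scalar while its contents carry the loop
+  # dimension.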
+ pfor_input.stack_inputs()
+ return wrap(
+ gen_dataset_ops.optional_from_value([x.t for x in pfor_input.inputs]),
+ True)
+
+
+@RegisterPFor("OptionalGetValue")
+def _convert_optional_get_value(pfor_input):
+ handle = pfor_input.stacked_input(0)
+ output_types = pfor_input.get_attr("output_types")
+ original_output_shapes = pfor_input.get_attr("output_shapes")
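+  # The contents of the Optional were stacked by the OptionalFromValue
+  # converter, so prepend the loop length to each static output shape.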
+ output_shapes = []
+ for shape in original_output_shapes:
+ shape = tensor_shape.TensorShape(shape)
+ loop_len_shape = tensor_shape.TensorShape(
+ [tensor_util.constant_value(pfor_input.pfor.loop_len_vector)])
+ shape = loop_len_shape.concatenate(shape)
+ output_shapes.append(shape.as_proto())
+ results = gen_dataset_ops.optional_get_value(handle, output_types,
+ output_shapes)
+ return [wrap(t, True) for t in results]
+
+
@RegisterPFor("TensorListReserve")
def _convert_tensor_list_reserve(pfor_input):
element_shape = pfor_input.unstacked_input(0)
@@ -4275,7 +4320,7 @@
shape = shape.merge_with(output_shapes[i])
pfor_input = self._pfor_input.input(i)
if pfor_input.is_stacked:
- if _is_tensor_list(pfor_input.t):
+ if _is_variant_with_internal_stacking(pfor_input.t):
shape = tensor_shape.TensorShape([]).concatenate(shape)
else:
shape = tensor_shape.TensorShape([None]).concatenate(shape)