Add pfor converters and basic type information to optionals

Like TensorLists, pfor vectorizes the content rather than the variant-dtype tensor itself.

Fixes #44502.

PiperOrigin-RevId: 342694491
Change-Id: I4f0a7cab5c59cdae85e1c34d6a479b0dc562e560
diff --git a/tensorflow/core/framework/types.proto b/tensorflow/core/framework/types.proto
index 01b5985..e5f3303 100644
--- a/tensorflow/core/framework/types.proto
+++ b/tensorflow/core/framework/types.proto
@@ -84,4 +84,6 @@
   ST_INVALID = 0;
   // "tensorflow::TensorList" in the variant type registry.
   ST_TENSOR_LIST = 1;
+  // "tensorflow::data::Optional" in the variant type registry.
+  ST_OPTIONAL = 2;
 }
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index c4bd397..e00ebf7 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -853,7 +853,18 @@
     .Input("components: Toutput_types")
     .Output("optional: variant")
     .Attr("Toutput_types: list(type) >= 1")
-    .SetShapeFn(shape_inference::ScalarShape);
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      std::vector<DataType> dtypes;
+      TF_RETURN_IF_ERROR(c->GetAttr("Toutput_types", &dtypes));
+      c->set_output(0, c->Scalar());
+      std::vector<shape_inference::ShapeAndType> shapes_and_types;
+      shapes_and_types.reserve(c->num_inputs());
+      for (int i = 0; i < c->num_inputs(); ++i) {
+        shapes_and_types.emplace_back(c->input(i), dtypes[i], ST_OPTIONAL);
+      }
+      c->set_output_handle_shapes_and_types(0, shapes_and_types);
+      return Status::OK();
+    });
 
 REGISTER_OP("OptionalNone")
     .Output("optional: variant")
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index edb02b0..0063b7f 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -1844,6 +1844,24 @@
     self.assertAllClose(compute_jacobian(use_pfor=True),
                         compute_jacobian(use_pfor=False))
 
+  def test_cond_func_grad_jacobian(self):
+
+    @def_function.function
+    def f(x):
+      y = control_flow_ops.cond(x > 0., lambda: x**3., lambda: x**2.)
+      return y
+
+    with backprop.GradientTape(persistent=True) as tape:
+      x = constant_op.constant(1.)
+      tape.watch(x)
+      y = f(x)
+      grad = tape.gradient(y, x)
+    self.assertAllClose(3., grad)
+    jacobian = tape.jacobian(grad, x, experimental_use_pfor=False)
+    self.assertAllClose(6., jacobian)
+    jacobian_pfor = tape.jacobian(grad, x, experimental_use_pfor=True)
+    self.assertAllClose(6., jacobian_pfor)
+
 
 @test_util.run_all_in_graph_and_eager_modes
 class BatchJacobianTest(test.TestCase, parameterized.TestCase):
diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
index f641687..b765084 100644
--- a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
+++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
@@ -45,6 +45,7 @@
 from tensorflow.python.ops import control_flow_v2_toggles
 from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import functional_ops
+from tensorflow.python.ops import gen_dataset_ops
 from tensorflow.python.ops import gen_list_ops
 from tensorflow.python.ops import gen_nn_ops
 from tensorflow.python.ops import gradient_checker_v2
@@ -1159,6 +1160,21 @@
     self._test_loop_fn(loop_fn, 2)
 
 
+class OptionalTest(PForTestCase):
+
+  def test_optional_from_value(self):
+
+    def loop_fn(i):
+      o = gen_dataset_ops.optional_from_value(
+          [i, i + 1, constant_op.constant(3)])
+      gen_dataset_ops.optional_none()
+      return gen_dataset_ops.optional_get_value(
+          o, [dtypes.int32, dtypes.int32, dtypes.int32],
+          [[], [], []])
+
+    self._test_loop_fn(loop_fn, 2)
+
+
 class StackTest(PForTestCase):
 
   @test_util.run_v1_only("b/122612051")
diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py
index 2b02d5e..c9431fa 100644
--- a/tensorflow/python/ops/parallel_for/pfor.py
+++ b/tensorflow/python/ops/parallel_for/pfor.py
@@ -46,6 +46,7 @@
 from tensorflow.python.ops import custom_gradient
 from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import gen_dataset_ops
 from tensorflow.python.ops import gen_image_ops
 from tensorflow.python.ops import gen_linalg_ops
 from tensorflow.python.ops import gen_list_ops
@@ -83,22 +84,45 @@
   handle_data = resource_variable_ops.get_eager_safe_handle_data(t)
   if not handle_data.is_set:
     return None
-  if len(handle_data.shape_and_type) != 1:
-    raise ValueError("Expected handle data of length 1, got {!r} of length {}"
-                     .format(handle_data, len(handle_data.shape_and_type)))
-  return handle_data.shape_and_type[0]
+  return handle_data.shape_and_type
 
 
-def _is_tensor_list(t):
-  """True if `t` is a TensorList, False if it isn't, None if unknown."""
+def _is_variant_with_internal_stacking(t):
+  """Identifies variant tensors which pfor always maintains as scalars.
+
+  For these, the pfor tensor is recorded as "stacked" if the content of the
+  variant tensor (e.g. the elements of a TensorList) are all stacked.
+
+  Args:
+    t: A tensor to identify.
+  Returns:
+    True if `t` is a TensorList/Optional, False not, None if unknown.
+  """
   if t.dtype != dtypes.variant:
     return False
-  shape_and_type = _variant_handle_data(t)
-  if shape_and_type is None:
-    # TODO(b/169968286): Identify all variant tensors (e.g. optionals) and we
-    # can make this an error instead of assuming TensorLists have handle data.
-    return None  # Presumed not a TensorList
-  return shape_and_type.specialized_type == types_pb2.ST_TENSOR_LIST
+  shapes_and_types = _variant_handle_data(t)
+  if shapes_and_types is None or not shapes_and_types:
+    # TODO(b/169968286): Identify all variant tensors (e.g. maps) and we can
+    # make this an error instead of assuming TensorLists have handle data.
+    return None  # Presumed not a TensorList/Optional
+  return (shapes_and_types[0].specialized_type == types_pb2.ST_TENSOR_LIST or
+          shapes_and_types[0].specialized_type == types_pb2.ST_OPTIONAL)
+
+
+def _parse_variant_shapes_and_types(t):
+  """Extracts shape and dtype information from a variant tensor `t`."""
+  shapes_and_types = _variant_handle_data(t)
+  if shapes_and_types is None or not shapes_and_types:
+    raise ValueError("Required handle data not set for {!r}".format(t))
+  if shapes_and_types[0].specialized_type == types_pb2.ST_TENSOR_LIST:
+    return shapes_and_types
+  else:
+    if shapes_and_types[0].specialized_type != types_pb2.ST_INVALID:
+      return shapes_and_types
+    else:
+      raise ValueError(
+          "Attempted to stack a variant-dtype tensor with no type set ({!r})"
+          .format(t))
 
 
 def _stack(t, length):
@@ -109,23 +133,19 @@
   # suitable since operations on stacked handles may expect a vectorized version
   # of the variant.
   if t.dtype == dtypes.variant:
-    shape_and_type = _variant_handle_data(t)
-    if shape_and_type is None:
-      raise ValueError("Required handle data not set for {!r}".format(t))
-    if shape_and_type.specialized_type == types_pb2.ST_TENSOR_LIST:
+    shapes_and_types = _parse_variant_shapes_and_types(t)
+    if shapes_and_types[0].specialized_type == types_pb2.ST_TENSOR_LIST:
+      if len(shapes_and_types) != 1:
+        raise ValueError(
+            "Expected handle data of length 1, got {!r} of length {}"
+            .format(shapes_and_types, len(shapes_and_types)))
       return wrap(
-          _stack_tensor_list(t, shape_and_type.dtype, length),
+          _stack_tensor_list(t, shapes_and_types[0].dtype, length),
           True)
     else:
-      if shape_and_type.specialized_type != types_pb2.ST_INVALID:
-        raise ValueError(
-            ("Attempted to stack an unhandled variant-dtype tensor of "
-             "type {!r} ({!r})").format(
-                 shape_and_type.specialized_type, t))
-      else:
-        raise ValueError(
-            "Attempted to stack a variant-dtype tensor with no type set ({!r})"
-            .format(t))
+      raise ValueError(
+          ("Attempted to stack an unhandled variant-dtype tensor of "
+           "type {!r} ({!r})").format(shapes_and_types[0].specialized_type, t))
   ones = array_ops.ones_like(array_ops.shape(t))
   ones = array_ops.reshape(ones, [-1])
   length = array_ops.reshape(length, [-1])
@@ -1629,7 +1649,7 @@
                 else:
                   batch_dim = tensor_shape.TensorShape(loop_len)
                 output_shape = batch_dim.concatenate(output_shape)
-              if _is_tensor_list(new_output.t):
+              if _is_variant_with_internal_stacking(new_output.t):
                 new_output.t.set_shape([])
               else:
                 new_output.t.set_shape(output_shape)
@@ -3602,7 +3622,7 @@
 
 def _tile_variant_with_length(t, length):
   """stacks `t` `length` times."""
-  if _is_tensor_list(t):
+  if _is_variant_with_internal_stacking(t):
     # The content of TensorLists is vectorized, not the variant itself.
     return t
   original_tensor = t
@@ -3622,16 +3642,41 @@
 
 
 def _untile_variant(t):
-  if _is_tensor_list(t):
+  if _is_variant_with_internal_stacking(t):
     # The content of TensorLists is vectorized, not the variant itself.
     if not t.shape.is_compatible_with([]):
       raise AssertionError(
-          "Unexpectedly saw a TensorList with non-scalar shape: {!r}"
-          .format(t))
+          ("Unexpectedly saw a vectorized variant (e.g. TensorList) with "
+           "non-scalar shape: {!r}").format(t))
     return t
   return array_ops.gather(t, 0)
 
 
+@RegisterPFor("OptionalFromValue")
+def _convert_optional_from_value(pfor_input):
+  pfor_input.stack_inputs()
+  return wrap(
+      gen_dataset_ops.optional_from_value([x.t for x in pfor_input.inputs]),
+      True)
+
+
+@RegisterPFor("OptionalGetValue")
+def _convert_optional_get_value(pfor_input):
+  handle = pfor_input.stacked_input(0)
+  output_types = pfor_input.get_attr("output_types")
+  original_output_shapes = pfor_input.get_attr("output_shapes")
+  output_shapes = []
+  for shape in original_output_shapes:
+    shape = tensor_shape.TensorShape(shape)
+    loop_len_shape = tensor_shape.TensorShape(
+        [tensor_util.constant_value(pfor_input.pfor.loop_len_vector)])
+    shape = loop_len_shape.concatenate(shape)
+    output_shapes.append(shape.as_proto())
+  results = gen_dataset_ops.optional_get_value(handle, output_types,
+                                               output_shapes)
+  return [wrap(t, True) for t in results]
+
+
 @RegisterPFor("TensorListReserve")
 def _convert_tensor_list_reserve(pfor_input):
   element_shape = pfor_input.unstacked_input(0)
@@ -4275,7 +4320,7 @@
       shape = shape.merge_with(output_shapes[i])
       pfor_input = self._pfor_input.input(i)
       if pfor_input.is_stacked:
-        if _is_tensor_list(pfor_input.t):
+        if _is_variant_with_internal_stacking(pfor_input.t):
           shape = tensor_shape.TensorShape([]).concatenate(shape)
         else:
           shape = tensor_shape.TensorShape([None]).concatenate(shape)