[MLIR:TF/XLA] Shape inference of variant types

1) Use the same handling as resource types when calling the registry's shape functions.
2) Add some pass-through cases, because the shape functions often skip subshapes, which is a problem for tensor lists: since a while loop's tensor-list output is usually passed to other ops, the while op's shape function dropping the handle shapes causes cascading inference failures downstream.

PiperOrigin-RevId: 299238548
Change-Id: I63ced0bde726a04ddd2aa204eef636646626c4eb
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
index c9db7e0..706524e 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
@@ -254,4 +254,28 @@
     %0 = "tf.Cast"(%arg0) : (tensor<*xf32>) -> (tensor<*xf32>)
     return %0 : tensor<*xf32>
   }
+
+  // CHECK-LABEL: func @while_variant
+  // CHECK-SAME: -> tensor<!tf.variant<tensor<16x1xf32>>>
+  func @while_variant(%arg0: tensor<!tf.variant<tensor<16x1xf32>>>) -> tensor<!tf.variant> {
+    // CHECK: tf.While
+    // CHECK-SAME: -> tensor<!tf.variant<tensor<16x1xf32>>>
+    %0 = "tf.While"(%arg0) {cond = @variant_cond_func, body = @variant_body_func, is_stateless = true} : (tensor<!tf.variant<tensor<16x1xf32>>>) -> tensor<!tf.variant>
+    // CHECK: tf.ZerosLike
+    // CHECK-SAME: -> tensor<!tf.variant<tensor<16x1xf32>>>
+    %1 = "tf.ZerosLike"(%0) : (tensor<!tf.variant>) -> tensor<!tf.variant>
+    // CHECK: tf.Identity
+    // CHECK-SAME: -> tensor<!tf.variant<tensor<16x1xf32>>>
+    %2 = "tf.Identity"(%1) : (tensor<!tf.variant>) -> tensor<!tf.variant>
+    return %2 : tensor<!tf.variant>
+  }
+  // CHECK-LABEL: func @variant_cond_func
+  func @variant_cond_func(%arg0: tensor<!tf.variant<tensor<16x1xf32>>>) -> tensor<i1> {
+    %0 = "tf._SomeOp"() : () -> tensor<i1>
+    return %0 : tensor<i1>
+  }
+  // CHECK-LABEL: func @variant_body_func
+  func @variant_body_func(%arg0: tensor<!tf.variant<tensor<16x1xf32>>>) -> tensor<!tf.variant<tensor<16x1xf32>>> {
+    return %arg0 : tensor<!tf.variant<tensor<16x1xf32>>>
+  }
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
index b3474e2..0a68780 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
@@ -31,6 +31,7 @@
 #include "mlir/IR/Diagnostics.h"  // TF:llvm-project
 #include "mlir/IR/Location.h"  // TF:llvm-project
 #include "mlir/IR/Operation.h"  // TF:llvm-project
+#include "mlir/IR/OperationSupport.h"  // TF:llvm-project
 #include "mlir/IR/StandardTypes.h"  // TF:llvm-project
 #include "mlir/IR/SymbolTable.h"  // TF:llvm-project
 #include "mlir/IR/Value.h"  // TF:llvm-project
@@ -184,20 +185,79 @@
   return false;
 }
 
+// Gets the subtype's shape and data type for `type`. Templated to support both
+// ResourceType and VariantType.
+template <typename T>
+std::unique_ptr<std::vector<
+    std::pair<tensorflow::PartialTensorShape, tensorflow::DataType>>>
+GetSubtypesHelper(Type type) {
+  auto type_with_subtypes =
+      type.cast<TensorType>().getElementType().dyn_cast<T>();
+  if (!type_with_subtypes || type_with_subtypes.getSubtypes().empty()) {
+    return nullptr;
+  }
+  auto shapes_and_types = absl::make_unique<std::vector<
+      std::pair<tensorflow::PartialTensorShape, tensorflow::DataType>>>();
+  for (auto subtype : type_with_subtypes.getSubtypes()) {
+    auto shape = GetShapeFromMlirType(subtype);
+    // handle_shapes_and_types requires all shapes to be known. So if any
+    // subtype is unknown, clear the vector.
+    if (!shape) {
+      shapes_and_types = nullptr;
+      break;
+    }
+    tensorflow::DataType dtype;
+    auto status =
+        tensorflow::ConvertToDataType(subtype.getElementType(), &dtype);
+    assert(status.ok() && "Unknown element type");
+    shapes_and_types->emplace_back(*shape, dtype);
+  }
+  return shapes_and_types;
+}
+
+// Gets the subtype's shape and data type for `type`.
+std::unique_ptr<std::vector<
+    std::pair<tensorflow::PartialTensorShape, tensorflow::DataType>>>
+GetSubtypes(Type type) {
+  auto subclasses = GetSubtypesHelper<TF::ResourceType>(type);
+  if (subclasses) return subclasses;
+  return GetSubtypesHelper<TF::VariantType>(type);
+}
+
+// Makes result types match the operand types. Returns if anything is changed.
+bool PassThroughOperandTypes(OperandRange operands, ResultRange results) {
+  bool changed = false;
+  for (auto entry : llvm::zip(operands, results)) {
+    Type operand_type = std::get<0>(entry).getType();
+    if (operand_type == std::get<1>(entry).getType()) continue;
+    std::get<1>(entry).setType(operand_type);
+    changed = true;
+  }
+  return changed;
+}
+
 }  // namespace
 
 bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
                                   int64_t graph_version) {
   assert(tf_dialect == op->getDialect());
+  // The shape function of these ops sometimes does not propagate subtypes
+  // (handle shapes) for resource and variant types. We use a simple passthrough
+  // to make sure they are preserved in the output.
+  if (isa<TF::IdentityOp>(op) || isa<TF::IdentityNOp>(op) ||
+      isa<TF::ZerosLikeOp>(op) || isa<TF::WhileOp>(op)) {
+    return PassThroughOperandTypes(op->getOperands(), op->getResults());
+  }
 
   // If no result for this op needs shape inference, we have a fast-path return.
-  // But if the type is a resource, we do not skip it because we might not have
-  // the handle shapes.
+  // But if the type is a resource/variant, we do not skip it because we might
+  // not have the handle shapes.
   if (llvm::all_of(op->getResultTypes(), [](Type type) {
         auto shape_type = type.dyn_cast<ShapedType>();
         return !shape_type ||
                (shape_type.hasStaticShape() &&
-                !shape_type.getElementType().isa<TF::ResourceType>());
+                !shape_type.getElementType().isa<TF::ResourceType>() &&
+                !shape_type.getElementType().isa<TF::VariantType>());
       })) {
     LLVM_DEBUG(llvm::dbgs() << "Skipping inference for statically shaped op '"
                             << op->getName() << "'.\n";);
@@ -282,29 +342,8 @@
     if (auto shape = GetShapeFromMlirType(operand_type)) {
       input_shapes[index] = *shape;
     }
-    // Collect the handle shapes and types for a resource.
-    if (auto resource_type = operand_type.cast<TensorType>()
-                                 .getElementType()
-                                 .dyn_cast<TF::ResourceType>()) {
-      if (resource_type.getSubtypes().empty()) continue;
-      auto shapes_and_types = absl::make_unique<std::vector<
-          std::pair<tensorflow::PartialTensorShape, tensorflow::DataType>>>();
-      for (auto subtype : resource_type.getSubtypes()) {
-        auto shape = GetShapeFromMlirType(subtype);
-        // handle_shapes_and_types requires all shapes to be known. So if any
-        // subtype is unknown, clear the vector.
-        if (!shape) {
-          shapes_and_types = nullptr;
-          break;
-        }
-        tensorflow::DataType dtype;
-        auto status =
-            tensorflow::ConvertToDataType(subtype.getElementType(), &dtype);
-        assert(status.ok() && "Unknown element type");
-        shapes_and_types->emplace_back(*shape, dtype);
-      }
-      handle_shapes_and_types[index] = std::move(shapes_and_types);
-    }
+    // Collect the handle shapes and types for a resource/variant.
+    handle_shapes_and_types[index] = GetSubtypes(operand_type);
   }
 
   // Perform the shape inference using an InferenceContext with the input
@@ -346,8 +385,9 @@
       return RankedTensorType::get(shape, element_type);
     };
     auto new_element_type = shaped_type.getElementType();
-    // Populate the handle shapes for a resource.
-    if (auto resource_type = new_element_type.dyn_cast<TF::ResourceType>()) {
+    // Populate the handle shapes for a resource/variant.
+    if (new_element_type.isa<TF::ResourceType>() ||
+        new_element_type.isa<TF::VariantType>()) {
       auto handle_shapes_types = c.output_handle_shapes_and_types(output);
       if (handle_shapes_types) {
         llvm::SmallVector<mlir::TensorType, 1> subtypes;
@@ -359,7 +399,11 @@
           assert(status.ok() && "Unknown element type");
           subtypes.push_back(get_tensor_type(shape_n_type.shape, element_type));
         }
-        new_element_type = TF::ResourceType::get(subtypes, op->getContext());
+        if (new_element_type.isa<TF::ResourceType>()) {
+          new_element_type = TF::ResourceType::get(subtypes, op->getContext());
+        } else {
+          new_element_type = TF::VariantType::get(subtypes, op->getContext());
+        }
       }
     }
     auto new_type = get_tensor_type(shape_handle, new_element_type);