[MLIR:TF/XLA] Shape inference of variant types
1) Use the same handling as resource types when calling the registry's shape functions.
2) Create some pass-through cases, because the shape functions often skip subshapes, which is now a problem for tensor lists. Since a while loop's tensor-list output is usually passed to other ops, the While op's shape function dropping the handle shapes causes cascading problems.
PiperOrigin-RevId: 299238548
Change-Id: I63ced0bde726a04ddd2aa204eef636646626c4eb
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
index c9db7e0..706524e 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
@@ -254,4 +254,28 @@
%0 = "tf.Cast"(%arg0) : (tensor<*xf32>) -> (tensor<*xf32>)
return %0 : tensor<*xf32>
}
+
+ // CHECK-LABEL: func @while_variant
+ // CHECK-SAME: -> tensor<!tf.variant<tensor<16x1xf32>>>
+ func @while_variant(%arg0: tensor<!tf.variant<tensor<16x1xf32>>>) -> tensor<!tf.variant> {
+ // CHECK: tf.While
+ // CHECK-SAME: -> tensor<!tf.variant<tensor<16x1xf32>>>
+ %0 = "tf.While"(%arg0) {cond = @variant_cond_func, body = @variant_body_func, is_stateless = true} : (tensor<!tf.variant<tensor<16x1xf32>>>) -> tensor<!tf.variant>
+ // CHECK: tf.ZerosLike
+ // CHECK-SAME: -> tensor<!tf.variant<tensor<16x1xf32>>>
+ %1 = "tf.ZerosLike"(%0) : (tensor<!tf.variant>) -> tensor<!tf.variant>
+ // CHECK: tf.Identity
+ // CHECK-SAME: -> tensor<!tf.variant<tensor<16x1xf32>>>
+ %2 = "tf.Identity"(%1) : (tensor<!tf.variant>) -> tensor<!tf.variant>
+ return %2 : tensor<!tf.variant>
+ }
+ // CHECK-LABEL: func @variant_cond_func
+ func @variant_cond_func(%arg0: tensor<!tf.variant<tensor<16x1xf32>>>) -> tensor<i1> {
+ %0 = "tf._SomeOp"() : () -> tensor<i1>
+ return %0 : tensor<i1>
+ }
+ // CHECK-LABEL: func @variant_body_func
+ func @variant_body_func(%arg0: tensor<!tf.variant<tensor<16x1xf32>>>) -> tensor<!tf.variant<tensor<16x1xf32>>> {
+ return %arg0 : tensor<!tf.variant<tensor<16x1xf32>>>
+ }
}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
index b3474e2..0a68780 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
@@ -31,6 +31,7 @@
#include "mlir/IR/Diagnostics.h" // TF:llvm-project
#include "mlir/IR/Location.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
+#include "mlir/IR/OperationSupport.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/SymbolTable.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
@@ -184,20 +185,79 @@
return false;
}
+// Gets the subtypes' shapes and data types for `type`. Templated to support
+// both ResourceType and VariantType.
+template <typename T>
+std::unique_ptr<std::vector<
+ std::pair<tensorflow::PartialTensorShape, tensorflow::DataType>>>
+GetSubtypesHelper(Type type) {
+ auto type_with_subtypes =
+ type.cast<TensorType>().getElementType().dyn_cast<T>();
+ if (!type_with_subtypes || type_with_subtypes.getSubtypes().empty()) {
+ return nullptr;
+ }
+ auto shapes_and_types = absl::make_unique<std::vector<
+ std::pair<tensorflow::PartialTensorShape, tensorflow::DataType>>>();
+ for (auto subtype : type_with_subtypes.getSubtypes()) {
+ auto shape = GetShapeFromMlirType(subtype);
+ // handle_shapes_and_types requires all shapes to be known. So if any
+ // subtype is unknown, clear the vector.
+ if (!shape) {
+ shapes_and_types = nullptr;
+ break;
+ }
+ tensorflow::DataType dtype;
+ auto status =
+ tensorflow::ConvertToDataType(subtype.getElementType(), &dtype);
+ assert(status.ok() && "Unknown element type");
+ shapes_and_types->emplace_back(*shape, dtype);
+ }
+ return shapes_and_types;
+}
+
+// Gets the subtypes' shapes and data types for `type`.
+std::unique_ptr<std::vector<
+ std::pair<tensorflow::PartialTensorShape, tensorflow::DataType>>>
+GetSubtypes(Type type) {
+ auto subclasses = GetSubtypesHelper<TF::ResourceType>(type);
+ if (subclasses) return subclasses;
+ return GetSubtypesHelper<TF::VariantType>(type);
+}
+
+// Makes result types match the operand types. Returns whether any type was changed.
+bool PassThroughOperandTypes(OperandRange operands, ResultRange results) {
+ bool changed = false;
+ for (auto entry : llvm::zip(operands, results)) {
+ Type operand_type = std::get<0>(entry).getType();
+ if (operand_type == std::get<1>(entry).getType()) continue;
+ std::get<1>(entry).setType(operand_type);
+ changed = true;
+ }
+ return changed;
+}
+
} // namespace
bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
int64_t graph_version) {
assert(tf_dialect == op->getDialect());
+ // The shape function of these ops sometimes does not propagate subtypes
+ // (handle shapes) for resource and variant types. We use a simple passthrough
+ // to make sure they are preserved in the output.
+ if (isa<TF::IdentityOp>(op) || isa<TF::IdentityNOp>(op) ||
+ isa<TF::ZerosLikeOp>(op) || isa<TF::WhileOp>(op)) {
+ return PassThroughOperandTypes(op->getOperands(), op->getResults());
+ }
// If no result for this op needs shape inference, we have a fast-path return.
- // But if the type is a resource, we do not skip it because we might not have
- // the handle shapes.
+ // But if the type is a resource/variant, we do not skip it because we might
+ // not have the handle shapes.
if (llvm::all_of(op->getResultTypes(), [](Type type) {
auto shape_type = type.dyn_cast<ShapedType>();
return !shape_type ||
(shape_type.hasStaticShape() &&
- !shape_type.getElementType().isa<TF::ResourceType>());
+ !shape_type.getElementType().isa<TF::ResourceType>() &&
+ !shape_type.getElementType().isa<TF::VariantType>());
})) {
LLVM_DEBUG(llvm::dbgs() << "Skipping inference for statically shaped op '"
<< op->getName() << "'.\n";);
@@ -282,29 +342,8 @@
if (auto shape = GetShapeFromMlirType(operand_type)) {
input_shapes[index] = *shape;
}
- // Collect the handle shapes and types for a resource.
- if (auto resource_type = operand_type.cast<TensorType>()
- .getElementType()
- .dyn_cast<TF::ResourceType>()) {
- if (resource_type.getSubtypes().empty()) continue;
- auto shapes_and_types = absl::make_unique<std::vector<
- std::pair<tensorflow::PartialTensorShape, tensorflow::DataType>>>();
- for (auto subtype : resource_type.getSubtypes()) {
- auto shape = GetShapeFromMlirType(subtype);
- // handle_shapes_and_types requires all shapes to be known. So if any
- // subtype is unknown, clear the vector.
- if (!shape) {
- shapes_and_types = nullptr;
- break;
- }
- tensorflow::DataType dtype;
- auto status =
- tensorflow::ConvertToDataType(subtype.getElementType(), &dtype);
- assert(status.ok() && "Unknown element type");
- shapes_and_types->emplace_back(*shape, dtype);
- }
- handle_shapes_and_types[index] = std::move(shapes_and_types);
- }
+ // Collect the handle shapes and types for a resource/variant.
+ handle_shapes_and_types[index] = GetSubtypes(operand_type);
}
// Perform the shape inference using an InferenceContext with the input
@@ -346,8 +385,9 @@
return RankedTensorType::get(shape, element_type);
};
auto new_element_type = shaped_type.getElementType();
- // Populate the handle shapes for a resource.
- if (auto resource_type = new_element_type.dyn_cast<TF::ResourceType>()) {
+ // Populate the handle shapes for a resource/variant.
+ if (new_element_type.isa<TF::ResourceType>() ||
+ new_element_type.isa<TF::VariantType>()) {
auto handle_shapes_types = c.output_handle_shapes_and_types(output);
if (handle_shapes_types) {
llvm::SmallVector<mlir::TensorType, 1> subtypes;
@@ -359,7 +399,11 @@
assert(status.ok() && "Unknown element type");
subtypes.push_back(get_tensor_type(shape_n_type.shape, element_type));
}
- new_element_type = TF::ResourceType::get(subtypes, op->getContext());
+ if (new_element_type.isa<TF::ResourceType>()) {
+ new_element_type = TF::ResourceType::get(subtypes, op->getContext());
+ } else {
+ new_element_type = TF::VariantType::get(subtypes, op->getContext());
+ }
}
}
auto new_type = get_tensor_type(shape_handle, new_element_type);