Add shape inference support for tf_device.cluster.

This follows the existing shape inference handling of pass-through ops
such as tf_device.launch: a cluster's result types are refined to the
operand types of its tf_device.return terminator.
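For illustration, a minimal before/after sketch distilled from the test
case added below (op names and types are taken from that test):

    // Before shape inference: the cluster result is unranked.
    %0 = "tf_device.cluster"() ({
      %1 = "tf.Cast"(%arg0) {Truncate = false}
          : (tensor<1x8x2xi32>) -> tensor<1x8x2xf32>
      tf_device.return %1 : tensor<1x8x2xf32>
    }) : () -> tensor<*xf32>

    // After shape inference: the result type is refined to the type of
    // the tf_device.return operand.
    %0 = "tf_device.cluster"() ({
      %1 = "tf.Cast"(%arg0) {Truncate = false}
          : (tensor<1x8x2xi32>) -> tensor<1x8x2xf32>
      tf_device.return %1 : tensor<1x8x2xf32>
    }) : () -> tensor<1x8x2xf32>

For uses that still require the original unrefined type (e.g. the addf
in the test below), the pass inserts a tf.Cast back to tensor<*xf32>.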
PiperOrigin-RevId: 328241534
Change-Id: I83cc618291f5dc84ba18f963b413d5117e98de8c
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
index 3e61357..26df602 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
@@ -530,6 +530,21 @@
return %3#0, %3#1 : tensor<*xf32>, tensor<*xf32>
}
 
+ // CHECK-LABEL: infer_device_cluster
+ func @infer_device_cluster(%arg0: tensor<1x8x2xi32>) -> (tensor<*xf32>, tensor<*xf32>) {
+ %0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
+ %1 = "tf_device.cluster"() ({
+ %2 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<1x8x2xi32>) -> tensor<1x8x2xf32>
+ tf_device.return %2 : tensor<1x8x2xf32>
+ // CHECK: () -> tensor<1x8x2xf32>
+ }) : () -> tensor<*xf32>
+ // CHECK: "tf.Cast"(%{{.*}}) {Truncate = false} : (tensor<1x8x2xf32>) -> tensor<*xf32>
+ // CHECK: (tensor<i32>, tensor<1x8x2xf32>) -> (tensor<1x8x1xf32>, tensor<1x8x1xf32>)
+ %3:2 = "tf.Split"(%0, %1) {device = ""} : (tensor<i32>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>)
+ %4 = addf %1, %1 : tensor<*xf32>
+ return %3#0, %3#1 : tensor<*xf32>, tensor<*xf32>
+ }
+
// CHECK-LABEL: func @tensor_cast(%arg0: tensor<1xi32>) -> tensor<1xi32>
func @tensor_cast(%arg0: tensor<1xi32>) -> tensor<*xi32> {
// CHECK: %[[RESULT:.*]] = tensor_cast
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
index 88ad787..0042216 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
@@ -116,12 +116,12 @@
 
// Returns if the shape inference pass supports an op outside the TF dialect.
bool IsSupportedNonTFOp(Operation* op) {
- return isa<ReturnOp, tf_device::ReturnOp, tf_executor::EnterOp,
- tf_executor::ExitOp, tf_executor::FetchOp, tf_executor::GraphOp,
- tf_executor::IslandOp, tf_executor::LoopCondOp,
- tf_executor::MergeOp, tf_executor::NextIterationSinkOp,
- tf_executor::SwitchNOp, tf_executor::SwitchOp,
- tf_executor::YieldOp>(op);
+ return isa<ReturnOp, tf_device::ReturnOp, tf_device::ClusterOp,
+ tf_device::LaunchOp, tf_executor::EnterOp, tf_executor::ExitOp,
+ tf_executor::FetchOp, tf_executor::GraphOp, tf_executor::IslandOp,
+ tf_executor::LoopCondOp, tf_executor::MergeOp,
+ tf_executor::NextIterationSinkOp, tf_executor::SwitchNOp,
+ tf_executor::SwitchOp, tf_executor::YieldOp>(op);
}
 
// Returns whether a cast back would need to be inserted, e.g., whether the
@@ -745,6 +745,11 @@
return RefineTypeForPassThroughOperands(op, terminator->getOperands(),
op->getResults());
}
+ if (auto cluster_op = dyn_cast<tf_device::ClusterOp>(op)) {
+ auto terminator = cluster_op.GetBody().getTerminator();
+ return RefineTypeForPassThroughOperands(op, terminator->getOperands(),
+ op->getResults());
+ }
if (op->hasTrait<OpTrait::SameOperandsAndResultShape>()) {
return RefineShapeForPassThroughOps(op);
}