Handle _XlaHostComputeMlir op in shape inference pass.
To infer shapes for a _XlaHostComputeMlir op, the serialized module stored in its host_mlir_module attribute (which contains the function representing the host computation) is deserialized, shape inference is run on that function, and the refined function return types are used to refine the op's result types.
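
For example (a sketch mirroring the new shape_inference.mlir test, with most
attributes elided), given

  %0 = "tf._XlaHostComputeMlir"(%arg0) {host_mlir_module = "<serialized
      module containing func @host_func(tensor<*xf32>) -> tensor<*xf32>>", ...}
      : (tensor<2xf32>) -> tensor<*xf32>

the pass deserializes the module, propagates the tensor<2xf32> operand shape
through @host_func, and refines the op's result type accordingly:

  %0 = "tf._XlaHostComputeMlir"(%arg0) {...} : (tensor<2xf32>) -> tensor<2xf32>

If host_mlir_module is empty, the op is left unchanged.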
PiperOrigin-RevId: 366277002
Change-Id: I6394390ff49144745f86f44b21e9ba9d91982d21
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
index 0cd7b35..8156733 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
@@ -1170,6 +1170,7 @@
%1 = shape.shape_of %0 : tensor<*xindex> -> tensor<?xindex>
return %0, %1 : tensor<*xindex>, tensor<?xindex>
}
+
// CHECK-LABEL: func @partitioned_called_const_index
// CHECK-SAME: -> tensor<index>
func @partitioned_called_const_index() -> (tensor<*xindex>) {
@@ -1192,4 +1193,22 @@
return %1#1 : tensor<*x!quant.uniform<u8:f32, 0.007:128>>
}
+
+ // CHECK-LABEL: func @xla_host_compute_mlir_empty_module
+ func @xla_host_compute_mlir_empty_module(%arg0: tensor<2xf32>) -> tensor<*xf32> {
+ // CHECK: "tf._XlaHostComputeMlir"
+ // CHECK-SAME: -> tensor<*xf32>
+ %0 = "tf._XlaHostComputeMlir"(%arg0) {recv_key = "host_compute_channel_recv", send_key = "host_compute_channel_send", tpu_core = 0, host_mlir_module = ""} : (tensor<2xf32>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+ }
+
+ // CHECK-LABEL: func @xla_host_compute_mlir_shape_inferred
+ func @xla_host_compute_mlir_shape_inferred(%arg0: tensor<2xf32>) -> tensor<*xf32> {
+ // CHECK: "tf._XlaHostComputeMlir"
+ // CHECK-SAME: -> tensor<2xf32>
+ // CHECK: return
+ // CHECK-SAME: tensor<2xf32>
+ %0 = "tf._XlaHostComputeMlir"(%arg0) {recv_key = "host_compute_channel_recv", send_key = "host_compute_channel_send", tpu_core = 0, host_mlir_module = "module {\0A func @host_func(%arg0: tensor<*xf32>) -> tensor<*xf32> {\0A %0 = \22tf.Identity\22(%arg0) {_xla_outside_compilation = \22cluster1\22} : (tensor<*xf32>) -> tensor<*xf32> \0A return %0 : tensor<*xf32> \0A } \0A} \0A"} : (tensor<2xf32>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+ }
}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
index eb91d43..02ac830 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
@@ -59,6 +59,7 @@
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
#include "tensorflow/core/framework/shape_inference.h"
@@ -725,6 +726,10 @@
// yields.
bool InferShapeForIfRegion(IfRegionOp op);
+ // Infers the shape of _XlaHostComputeMlir based on the host computation
+ // module. Returns true if a return type was changed.
+ bool InferShapeForXlaHostComputeMlir(_XlaHostComputeMlirOp op);
+
// Infers the shape of ops that create TensorList. Specifically,
// TensorListReserveOp, EmptyTensorListOp and TensorListFromTensor ops. It
// refines the element shape if all tensors written to the list across all
@@ -886,6 +891,51 @@
return changed;
}
+bool ShapeInference::InferShapeForXlaHostComputeMlir(
+ _XlaHostComputeMlirOp host_compute_op) {
+  // Extract the module and function.
+  // The `_XlaHostComputeMlir` op verifier checks that the `host_mlir_module`
+  // attribute is well formed, so an error extracting the host function here
+  // should never occur; simply return if one does.
+ StringAttr host_module =
+ host_compute_op->getAttrOfType<StringAttr>("host_mlir_module");
+ if (host_module.getValue().empty()) return false;
+
+ mlir::OwningModuleRef module_for_func;
+ if (!tensorflow::DeserializeMlirModule(host_module.getValue().str(),
+ host_compute_op->getContext(),
+ &module_for_func)
+ .ok()) {
+ return false;
+ }
+
+ FuncOp func = module_for_func->lookupSymbol<FuncOp>("host_func");
+ if (!func) return false;
+
+  // Update the function type so inputs match the op's operand types.
+ FunctionType func_type = func.getType();
+ func.setType(FunctionType::get(func.getContext(),
+ host_compute_op.getOperandTypes(),
+ func_type.getResults()));
+
+  // Run shape inference on the function (at most 10 propagation iterations).
+ if (failed(PropagateShapeToRegions(host_compute_op.getOperandTypes(),
+ {&func.getBody()}, 10)))
+ return false;
+ if (failed(InferShapeForFunctionReturnType(func))) return false;
+
+ bool changed = false;
+  // Use refined function return types for the _XlaHostComputeMlir op.
+ for (auto result :
+ zip(host_compute_op.getResults(), func.getType().getResults())) {
+ changed = RefineResultType(host_compute_op, std::get<0>(result),
+ std::get<1>(result)) ||
+ changed;
+ }
+
+ return changed;
+}
+
bool ShapeInference::InferShapeForTensorListInitOps(Operation* op) {
DCOMMENT_OP(op, "Inferring shape for TensorList ");
Value handle = op->getResult(0);
@@ -1199,6 +1249,10 @@
while_region,
while_region.body().front().getTerminator()->getOperandTypes());
+ if (auto host_compute_op = dyn_cast<_XlaHostComputeMlirOp>(op)) {
+ return InferShapeForXlaHostComputeMlir(host_compute_op);
+ }
+
// Handle TensorList init operations by inferring shape from TensorList write
// operations. If we are unable to refine element shape here, proceed to use
// the InferenceContext below to get more precise shapes.