Handle _XlaHostComputeMlir op in shape inference pass.

To infer shape for _XlaHostComputeMlir op, the module containing the function representing the host computation is deserialized and the function is used to infer the output shape of _XlaHostComputeMlir.

PiperOrigin-RevId: 366277002
Change-Id: I6394390ff49144745f86f44b21e9ba9d91982d21
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
index 0cd7b35..8156733 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
@@ -1170,6 +1170,7 @@
     %1 = shape.shape_of %0 : tensor<*xindex> -> tensor<?xindex>
     return %0, %1 : tensor<*xindex>, tensor<?xindex>
   }
+
   // CHECK-LABEL: func @partitioned_called_const_index
   // CHECK-SAME: -> tensor<index>
   func @partitioned_called_const_index() -> (tensor<*xindex>) {
@@ -1192,4 +1193,22 @@
 
     return %1#1 : tensor<*x!quant.uniform<u8:f32, 0.007:128>>
   }
+
+  // CHECK-LABEL: func @xla_host_compute_mlir_empty_module
+  func @xla_host_compute_mlir_empty_module(%arg0: tensor<2xf32>) -> tensor<*xf32> {
+    // CHECK: "tf._XlaHostComputeMlir"
+    // CHECK-SAME: -> tensor<*xf32>
+    %0 = "tf._XlaHostComputeMlir"(%arg0) {recv_key = "host_compute_channel_recv", send_key = "host_compute_channel_send", tpu_core = 0, host_mlir_module = ""} : (tensor<2xf32>) -> tensor<*xf32>
+    return %0 : tensor<*xf32>
+  }
+
+  // CHECK-LABEL: func @xla_host_compute_mlir_shape_inferred
+  func @xla_host_compute_mlir_shape_inferred(%arg0: tensor<2xf32>) -> tensor<*xf32> {
+    // CHECK: "tf._XlaHostComputeMlir"
+    // CHECK-SAME: -> tensor<2xf32>
+    // CHECK: return
+    // CHECK-SAME: tensor<2xf32>
+    %0 = "tf._XlaHostComputeMlir"(%arg0) {recv_key = "host_compute_channel_recv", send_key = "host_compute_channel_send", tpu_core = 0, host_mlir_module = "module  {\0A  func @host_func(%arg0: tensor<*xf32>) -> tensor<*xf32> {\0A    %0 = \22tf.Identity\22(%arg0) {_xla_outside_compilation = \22cluster1\22} : (tensor<*xf32>) -> tensor<*xf32> \0A    return %0 : tensor<*xf32> \0A  } \0A} \0A"} : (tensor<2xf32>) -> tensor<*xf32>
+    return %0 : tensor<*xf32>
+  }
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
index eb91d43..02ac830 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
@@ -59,6 +59,7 @@
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
 #include "tensorflow/core/framework/shape_inference.h"
@@ -725,6 +726,10 @@
   // yields.
   bool InferShapeForIfRegion(IfRegionOp op);
 
+  // Infers the shape of _XlaHostComputeMlir based on the host computation
+  // module.  Returns true if a return type was changed.
+  bool InferShapeForXlaHostComputeMlir(_XlaHostComputeMlirOp op);
+
   // Infers the shape of ops that create TensorList. Specifically,
   // TensorListReserveOp, EmptyTensorListOp and TensorListFromTensor ops. It
   // refines the element shape if all tensors written to the list across all
@@ -886,6 +891,51 @@
   return changed;
 }
 
+bool ShapeInference::InferShapeForXlaHostComputeMlir(
+    _XlaHostComputeMlirOp host_compute_op) {
+  // Extract the module and function.
+  // The '_XlaHostComputeMlir` verifier verifies that `host_mlir_module`
+  // attribute is well formed, so we just return in case of an error in
+  // extracting the host function since it should never occur.
+  StringAttr host_module =
+      host_compute_op->getAttrOfType<StringAttr>("host_mlir_module");
+  if (host_module.getValue().empty()) return false;
+
+  mlir::OwningModuleRef module_for_func;
+  if (!tensorflow::DeserializeMlirModule(host_module.getValue().str(),
+                                         host_compute_op->getContext(),
+                                         &module_for_func)
+           .ok()) {
+    return false;
+  }
+
+  FuncOp func = module_for_func->lookupSymbol<FuncOp>("host_func");
+  if (!func) return false;
+
+  // Update/use input shapes for function.
+  FunctionType func_type = func.getType();
+  func.setType(FunctionType::get(func.getContext(),
+                                 host_compute_op.getOperandTypes(),
+                                 func_type.getResults()));
+
+  // Run shape inference on the function.
+  if (failed(PropagateShapeToRegions(host_compute_op.getOperandTypes(),
+                                     {&func.getBody()}, 10)))
+    return false;
+  if (failed(InferShapeForFunctionReturnType(func))) return false;
+
+  bool changed = false;
+  // Use refined function return shape for XlaHostComputeMlirOp.
+  for (auto result :
+       zip(host_compute_op.getResults(), func.getType().getResults())) {
+    changed = RefineResultType(host_compute_op, std::get<0>(result),
+                               std::get<1>(result)) ||
+              changed;
+  }
+
+  return changed;
+}
+
 bool ShapeInference::InferShapeForTensorListInitOps(Operation* op) {
   DCOMMENT_OP(op, "Inferring shape for TensorList ");
   Value handle = op->getResult(0);
@@ -1199,6 +1249,10 @@
         while_region,
         while_region.body().front().getTerminator()->getOperandTypes());
 
+  if (auto host_compute_op = dyn_cast<_XlaHostComputeMlirOp>(op)) {
+    return InferShapeForXlaHostComputeMlir(host_compute_op);
+  }
+
   // Handle TensorList init operations by inferring shape from TensorList write
   // operations. If we are unable to refine element shape here, proceed to use
   // the InferenceContext below to get more precise shapes.