Refactor TPUExtractOutsideCompilation pass.

Extracts code into finer-grained helper functions.  Passes multiple ops into some of them in preparation for handling ops with dynamic shapes, which require capturing the function for shape inference.
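
A rough sketch of the new structure (simplified; the exact attribute checks and
shape handling are in the diff below, and IsOutsideCompiled is just a stand-in
for them): the per-block loop now accumulates a small cluster of outside
compiled ops and hands it to a dedicated helper.

    llvm::SmallSetVector<Operation*, 4> clustered_ops;
    for (Operation& op : llvm::make_early_inc_range(*src)) {
      // Stand-in for the kXlaOutsideCompilationAttr / ancestor checks.
      if (!IsOutsideCompiled(&op)) continue;
      clustered_ops.insert(&op);
      auto operands = GetExternalOperands(tpu_cluster, clustered_ops);
      auto outputs = GetExternalOutputs(clustered_ops);
      MoveOpsToHost(clustered_ops, operands, outputs, insertion_point,
                    compilation_key, device_ordinal, communication_key_index);
      clustered_ops.clear();
    }

For now each cluster holds a single op; the grouping becomes useful once
dynamically shaped ops have to be moved together with the serialized function
used for shape inference.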

PiperOrigin-RevId: 368037987
Change-Id: I77bdbaad5e5ab0a0a7b6c2dd781d1d4d39e072ff
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc
index 3454836..59eb23e 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc
@@ -196,37 +196,62 @@
   return launch_op;
 }
 
-llvm::SmallSetVector<Value, 4> GetExternalOperands(
-    tf_device::ClusterOp tpu_cluster, Operation* op) {
+// Returns the operands of `cluster_ops` that need to be communicated from
+// device->host. This handles the case where all operands have a static shape.
+llvm::SmallSetVector<Value, 4> GetStaticExternalOperands(
+    tf_device::ClusterOp tpu_cluster,
+    const llvm::SmallSetVector<Operation*, 4>& cluster_ops) {
   llvm::SmallSetVector<Value, 4> external_values;
-  op->walk([&](Operation* walked_op) {
-    if (llvm::isa<TF::_XlaRecvAtHostV2Op, TF::_XlaSendFromHostV2Op>(walked_op))
-      return WalkResult::advance();
-    for (Value v : walked_op->getOperands()) {
-      if (auto* defining_op = v.getDefiningOp()) {
-        if (!op->isAncestor(defining_op) &&
-            tpu_cluster->isAncestor(defining_op) &&
-            !HasOutsideCompilationAncestor(defining_op) &&
-            !llvm::isa<TF::_XlaRecvAtHostV2Op>(defining_op)) {
-          external_values.insert(v);
+  for (Operation* op : cluster_ops) {
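+    // Walk `op`, including nested regions, and collect operands that are
+    // defined inside the TPU cluster but outside of this outside compiled op
+    // and are not themselves produced by other outside compiled ops.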
+    op->walk([&](Operation* walked_op) {
+      if (llvm::isa<TF::_XlaRecvAtHostV2Op, TF::_XlaSendFromHostV2Op>(
+              walked_op))
+        return WalkResult::advance();
+      for (Value v : walked_op->getOperands()) {
+        if (auto* defining_op = v.getDefiningOp()) {
+          if (!op->isAncestor(defining_op) &&
+              tpu_cluster->isAncestor(defining_op) &&
+              !HasOutsideCompilationAncestor(defining_op) &&
+              !llvm::isa<TF::_XlaRecvAtHostV2Op>(defining_op)) {
+            external_values.insert(v);
+          }
+          continue;
         }
-        continue;
+        auto block_arg = v.cast<BlockArgument>();
+        if (block_arg.getParentRegion() == op->getParentRegion())
+          external_values.insert(v);
       }
-      auto block_arg = v.cast<BlockArgument>();
-      if (block_arg.getParentRegion() == op->getParentRegion())
-        external_values.insert(v);
-    }
-    return WalkResult::advance();
-  });
+      return WalkResult::advance();
+    });
+  }
   return external_values;
 }
 
-llvm::SmallSetVector<Value, 4> GetExternalOutputs(Operation* op) {
+// Returns a SmallSetVector containing all of the operands that need to be
+// communicated from device->host.
+llvm::SmallSetVector<Value, 4> GetExternalOperands(
+    tf_device::ClusterOp tpu_cluster,
+    const llvm::SmallSetVector<Operation*, 4>& cluster_ops) {
+  return GetStaticExternalOperands(tpu_cluster, cluster_ops);
+}
+
+// Gets all outputs that need to be communicated from host->device.
+llvm::SmallSetVector<Value, 4> GetExternalOutputs(
+    const llvm::SmallSetVector<Operation*, 4>& cluster_ops) {
   llvm::SmallSetVector<Value, 4> external_outputs;
-  for (Operation* user : op->getUsers()) {
-    if (!HasOutsideCompilationAncestor(user)) {
-      for (Value v : user->getOperands()) {
-        if (v.getDefiningOp() == op) external_outputs.insert(v);
+  for (Operation* op : cluster_ops) {
+    for (Operation* user : op->getUsers()) {
+      // Skip users in the same outside compilation cluster; both the
+      // defining op and the user op will be moved to the host together, so
+      // no value needs to be communicated back to the device.
+      if (cluster_ops.count(user)) {
+        continue;
+      }
+      if (!HasOutsideCompilationAncestor(user)) {
+        for (Value v : user->getOperands()) {
+          if (v.getDefiningOp() == op) external_outputs.insert(v);
+        }
       }
     }
   }
@@ -236,19 +261,20 @@
 // Creates the HostCompute with `inputs` and `outputs`
 // using `communication_key`.
 TF::_XlaHostComputeMlirOp CreateHostCompute(
-    OpBuilder& builder, Operation* loc_op,
+    OpBuilder& builder, Location loc,
     const llvm::SmallSetVector<Value, 4>& inputs, llvm::ArrayRef<Value> outputs,
     llvm::StringRef args_communication_key,
-    llvm::StringRef retvals_communication_key) {
+    llvm::StringRef retvals_communication_key,
+    llvm::StringRef serialized_func_module) {
   llvm::SmallVector<Type, 4> device_output_types;
   for (const auto& output : outputs)
     device_output_types.push_back(output.getType());
   auto host_compute = builder.create<TF::_XlaHostComputeMlirOp>(
-      loc_op->getLoc(), device_output_types, inputs.getArrayRef(),
+      loc, device_output_types, inputs.getArrayRef(),
       builder.getStringAttr(args_communication_key),
       builder.getStringAttr(retvals_communication_key),
       /*tpu_core=*/builder.getI64IntegerAttr(0),
-      /*host_mlir_module=*/builder.getStringAttr(""));
+      /*host_mlir_module=*/builder.getStringAttr(serialized_func_module));
   return host_compute;
 }
 
@@ -257,6 +283,121 @@
               StringAttr::get(op->getContext(), "temp"));
 }
 
+// Replaces uses of `external_operands` with the results of `recv_at_host`.
+// For static shapes, a use is replaced if it is in the same region as the
+// insertion point or if the using op is outside compiled and will be moved
+// to the host later.
+void ReplaceExternalOperandUsage(
+    const llvm::SmallSetVector<Value, 4>& external_operands,
+    Operation* recv_at_host, Operation* insertion_point,
+    Block* original_op_block) {
+  auto replace_operand_usage = [&](OpOperand& operand) {
+    return insertion_point->getParentRegion()->isAncestor(
+               operand.getOwner()->getParentRegion()) ||
+           (HasOutsideCompilationAncestor(operand.getOwner()) &&
+            original_op_block == operand.getOwner()->getBlock());
+  };
+  for (auto result : llvm::zip(external_operands, recv_at_host->getResults())) {
+    Value external_operand = std::get<0>(result);
+    external_operand.replaceUsesWithIf(std::get<1>(result),
+                                       replace_operand_usage);
+  }
+}
+
+// Replaces uses of `external_outputs`, which are values produced by the
+// outside compiled ops, with the corresponding results of `host_compute`.
+void ReplaceExternalOutputUsage(
+    const llvm::SmallSetVector<Value, 4>& external_outputs,
+    TF::_XlaHostComputeMlirOp host_compute) {
+  auto replace_output_usage = [&](OpOperand& operand) {
+    // Don't replace output usages in host computation or for outside
+    // compiled ops.
+    return !operand.get().getDefiningOp()->getParentRegion()->isAncestor(
+               operand.getOwner()->getParentRegion()) &&
+           !HasOutsideCompilationAncestor(operand.getOwner());
+  };
+  for (auto result : llvm::zip(external_outputs, host_compute.getResults())) {
+    Value external_output = std::get<0>(result);
+    external_output.replaceUsesWithIf(std::get<1>(result),
+                                      replace_output_usage);
+  }
+}
+
+// Moves `clustered_ops` to run on the host and adds communication ops to
+// transfer `external_operands` and `external_outputs` between device and
+// host.  Inserts the new ops at `insertion_point` and uses `compilation_key`
+// and `device_ordinal` when creating the communication ops.
+void MoveOpsToHost(const llvm::SmallSetVector<Operation*, 4>& clustered_ops,
+                   const llvm::SmallSetVector<Value, 4>& external_operands,
+                   const llvm::SmallSetVector<Value, 4>& external_outputs,
+                   Operation* insertion_point, Value compilation_key,
+                   Value device_ordinal, int& communication_key_index) {
+  OpBuilder builder(insertion_point);
+  Operation& op = *clustered_ops.back();
+  std::string args_communication_key =
+      llvm::formatv("host_compute_channel_{0}_args", (communication_key_index))
+          .str();
+  std::string retvals_communication_key =
+      llvm::formatv("host_compute_channel_{0}_retvals",
+                    (communication_key_index))
+          .str();
+
+  // Use a unique name when sending just the IfRegion predicate.  This is
+  // for readability and to match the key in the TF2XLA bridge.
+  if (clustered_ops.size() == 1 && llvm::isa<TF::IfRegionOp>(op) &&
+      external_operands.size() == 1) {
+    args_communication_key =
+        llvm::formatv("if_predicate_channel_{0}", (communication_key_index))
+            .str();
+  }
+
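+  // Currently left empty.  This is intended to carry the serialized host
+  // function once ops with dynamically shaped operands, which require
+  // capturing the function for shape inference, are supported.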
+  std::string serialized_func_module;
+
+  builder.setInsertionPoint(&op);
+  auto host_compute =
+      CreateHostCompute(builder, op.getLoc(), external_operands,
+                        external_outputs.getArrayRef(), args_communication_key,
+                        retvals_communication_key, serialized_func_module);
+  // Insert ops on the host side computation to receive data from device.
+  builder.setInsertionPoint(insertion_point);
+  llvm::SmallVector<Type, 4> host_operand_types;
+  for (const auto& operand : external_operands)
+    host_operand_types.push_back(operand.getType());
+
+  Operation* recv_at_host = CreateRecvAtHostOp(
+      builder, op.getLoc(), host_operand_types, compilation_key, device_ordinal,
+      args_communication_key);
+  Block* original_op_block = op.getBlock();
+  Operation* after_op = recv_at_host;
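+  // Move the outside compiled ops to the host computation, directly after
+  // the receive op and in their original order, and drop the device
+  // attribute since they no longer run on the TPU device.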
+  for (Operation* cluster_op : clustered_ops) {
+    cluster_op->moveAfter(after_op);
+    cluster_op->removeAttr(Identifier::get(kDeviceAttr, op.getContext()));
+    after_op = cluster_op;
+  }
+
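+  // Only create a send op when there are host->device outputs to return.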
+  if (!external_outputs.empty()) {
+    CreateSendFromHostOp(builder, op.getLoc(), external_outputs.getArrayRef(),
+                         compilation_key, device_ordinal,
+                         retvals_communication_key);
+  }
+
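+  // If no operands are sent to the host, the receive op is unused and can be
+  // erased; otherwise rewire the external operand uses to its results.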
+  if (external_operands.empty()) {
+    recv_at_host->erase();
+  } else {
+    ReplaceExternalOperandUsage(external_operands,
+                                /*recv_at_host=*/recv_at_host,
+                                /*insertion_point=*/insertion_point,
+                                /*original_op_block=*/original_op_block);
+  }
+
+  ReplaceExternalOutputUsage(external_outputs, host_compute);
+
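+  // If nothing is communicated in either direction, the host compute op is
+  // not needed; otherwise advance the index so the next cluster gets unique
+  // communication keys.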
+  if (external_operands.empty() && external_outputs.empty()) {
+    host_compute.erase();
+  } else {
+    ++communication_key_index;
+  }
+}
+
 // Move outside compiled ops in `src` to `insertion_point` in host
 // computation (may be temporarily with `tpu_cluster` but moved in subsequent
 // call to this method).  Communication ops are added in both `src` and at
@@ -268,17 +409,26 @@
                             Operation* insertion_point, Value compilation_key,
                             Value device_ordinal,
                             int& communication_key_index) {
-  OpBuilder builder(insertion_point);
+  // Contains all of the outside compiled operations that should be moved to the
+  // host using a single `_XlaHostComputeMlir` op.  This should only contain a
+  // single op except in the case where some of the input/output shapes are
+  // non-static.
+  llvm::SmallSetVector<Operation*, 4> clustered_ops;
+
   for (Operation& op : llvm::make_early_inc_range(*src)) {
     if (HasOutsideCompilationAncestorExclusive(&op) ||
         !op.hasAttrOfType<StringAttr>(kXlaOutsideCompilationAttr))
       continue;
 
+    clustered_ops.insert(&op);
+
     // Get the operands and outputs that need to be communicated between host
     // and device.  External operands are from device -> host and external
     // outputs are from host -> device.
-    auto external_operands = GetExternalOperands(tpu_cluster, &op);
-    auto external_outputs = GetExternalOutputs(&op);
+    llvm::SmallSetVector<Value, 4> external_operands =
+        GetExternalOperands(tpu_cluster, clustered_ops);
+    llvm::SmallSetVector<Value, 4> external_outputs =
+        GetExternalOutputs(clustered_ops);
 
     // Check if any of the outside compiled input/output shapes can be refined.
     for (const auto& operand : external_operands) {
@@ -294,75 +444,10 @@
             "not currently supported.  See b/177523289.");
     }
 
-    builder.setInsertionPoint(&op);
-    std::string args_communication_key =
-        llvm::formatv("host_compute_channel_{0}_args",
-                      (communication_key_index))
-            .str();
-    if (llvm::isa<TF::IfRegionOp>(op) && external_operands.size() == 1) {
-      args_communication_key =
-          llvm::formatv("if_predicate_channel_{0}", (communication_key_index))
-              .str();
-    }
-    std::string retvals_communication_key =
-        llvm::formatv("host_compute_channel_{0}_retvals",
-                      (communication_key_index))
-            .str();
-    auto host_compute = CreateHostCompute(
-        builder, &op, external_operands, external_outputs.getArrayRef(),
-        args_communication_key, retvals_communication_key);
-    // Insert ops on the host side computation to receive data from device.
-    builder.setInsertionPoint(insertion_point);
-    llvm::SmallVector<Type, 4> host_operand_types;
-    for (const auto& operand : external_operands)
-      host_operand_types.push_back(operand.getType());
-
-    auto recv_at_host = CreateRecvAtHostOp(
-        builder, op.getLoc(), host_operand_types, compilation_key,
-        device_ordinal, args_communication_key);
-    auto original_op_block = op.getBlock();
-    op.moveAfter(recv_at_host);
-    op.removeAttr(Identifier::get(kDeviceAttr, op.getContext()));
-    if (!external_outputs.empty()) {
-      CreateSendFromHostOp(builder, op.getLoc(), external_outputs.getArrayRef(),
-                           compilation_key, device_ordinal,
-                           retvals_communication_key);
-    }
-    // Replace operand usages if op is in the same region as insertion or if
-    // the op is outside compiled and will be moved to host later.
-    auto replace_operand_usage = [&](OpOperand& operand) {
-      return insertion_point->getParentRegion()->isAncestor(
-                 operand.getOwner()->getParentRegion()) ||
-             (HasOutsideCompilationAncestor(operand.getOwner()) &&
-              original_op_block == operand.getOwner()->getBlock());
-    };
-    if (external_operands.empty()) {
-      recv_at_host->erase();
-    } else {
-      for (auto result :
-           llvm::zip(external_operands, recv_at_host->getResults())) {
-        Value external_operand = std::get<0>(result);
-        external_operand.replaceUsesWithIf(std::get<1>(result),
-                                           replace_operand_usage);
-      }
-    }
-    // Don't replace output usages in host computation or for outside
-    // compiled ops.
-    auto replace_output_usage = [&](OpOperand& operand) {
-      return !op.getParentRegion()->isAncestor(
-                 operand.getOwner()->getParentRegion()) &&
-             !HasOutsideCompilationAncestor(operand.getOwner());
-    };
-    for (auto result : llvm::zip(external_outputs, host_compute.getResults())) {
-      Value external_output = std::get<0>(result);
-      external_output.replaceUsesWithIf(std::get<1>(result),
-                                        replace_output_usage);
-    }
-    if (external_operands.empty() && external_outputs.empty()) {
-      host_compute.erase();
-    } else {
-      ++communication_key_index;
-    }
+    MoveOpsToHost(clustered_ops, external_operands, external_outputs,
+                  insertion_point, compilation_key, device_ordinal,
+                  communication_key_index);
+    clustered_ops.clear();
   }
   return success();
 }