Generalize tf-rewrite-tpu-embedding-ops pass to handle embedding ops within regions

Currently, the pass assumes that embedding ops appear directly in a function. This worked only because the functional-control-flow-to-regions pass introduces call functions, so the embedding ops ended up in a separate function anyway.

That assumption will no longer hold when this pass is used to rewrite these ops in the MLIR module before it is converted to a Graph for compilation by the second phase of the old bridge.

PiperOrigin-RevId: 362586719
Change-Id: Ieaabcc0db8a44f6c721d1842da0720e8d9389797
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/rewrite_tpu_embedding_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/rewrite_tpu_embedding_ops.mlir
index f5bfb3e..01591ca 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/rewrite_tpu_embedding_ops.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/rewrite_tpu_embedding_ops.mlir
@@ -40,3 +40,36 @@
   %0 = "tf.Add"(%arg0, %arg0) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
   return %0 : tensor<2x2xf32>
 }
+
+// CHECK-LABEL: func @nested_embedding_op
+func @nested_embedding_op(%arg0: tensor<i1>, %arg1: tensor<512x256xf32>) -> (tensor<512x256xf32>) {
+  %1 = "tf.IfRegion"(%arg0) ({
+    // CHECK: "tf._RecvTPUEmbeddingDeduplicationData"
+    // CHECK: "tf._RecvTPUEmbeddingActivations"
+    // CHECK-NOT: tf.RecvTPUEmbeddingActivations
+    %0 = "tf.RecvTPUEmbeddingActivations"() {config = "\0A%\0A\0Dwatches_table\10\F5\03\18\80\02 \01*\0C\1A\00j\05\0D\00\00\80?\88\01\01\10\02\18\80\04 \01(\02"} : () -> tensor<512x256xf32>
+    "tf.Yield"(%0) : (tensor<512x256xf32>) -> ()
+  }, {
+    "tf.Yield"(%arg1) : (tensor<512x256xf32>) -> ()
+  }) { is_stateless = true}: (tensor<i1>) -> tensor<512x256xf32>
+  return %1 : tensor<512x256xf32>
+}
+
+// CHECK-LABEL: func @doubly_nested_embedding_op
+func @doubly_nested_embedding_op(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<512x256xf32>) -> (tensor<512x256xf32>) {
+  %2 = "tf.IfRegion"(%arg0) ({
+    %1 = "tf.IfRegion"(%arg1) ({
+      // CHECK: "tf._RecvTPUEmbeddingDeduplicationData"
+      // CHECK: "tf._RecvTPUEmbeddingActivations"
+      // CHECK-NOT: tf.RecvTPUEmbeddingActivations
+      %0 = "tf.RecvTPUEmbeddingActivations"() {config = "\0A%\0A\0Dwatches_table\10\F5\03\18\80\02 \01*\0C\1A\00j\05\0D\00\00\80?\88\01\01\10\02\18\80\04 \01(\02"} : () -> tensor<512x256xf32>
+      "tf.Yield"(%0) : (tensor<512x256xf32>) -> ()
+    }, {
+      "tf.Yield"(%arg2) : (tensor<512x256xf32>) -> ()
+    }) { is_stateless = true}: (tensor<i1>) -> tensor<512x256xf32>
+    "tf.Yield"(%1) : (tensor<512x256xf32>) -> ()
+  }, {
+    "tf.Yield"(%arg2) : (tensor<512x256xf32>) -> ()
+  }) { is_stateless = true}: (tensor<i1>) -> tensor<512x256xf32>
+  return %2 : tensor<512x256xf32>
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_tpu_embedding_ops.cc b/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_tpu_embedding_ops.cc
index c2b8a07..6551a9b 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_tpu_embedding_ops.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_tpu_embedding_ops.cc
@@ -49,33 +49,32 @@
 // assigns it to `result`, if present. If there are multiple such ops, returns
 // failure.
 template <typename OpT>
-LogicalResult GetOp(FuncOp func, OpT* result) {
+LogicalResult GetOp(Region* region, OpT* result) {
   *result = {};
-  for (auto op : func.getOps<OpT>()) {
+  for (auto op : region->getOps<OpT>()) {
     if (*result) return op.emitError("should be unique within a function");
     *result = op;
   }
   return success();
 }
 
-void RewriteTPUEmbeddingOps::runOnFunction() {
-  FuncOp func = getFunction();
-
+LogicalResult RunOnRegion(Region* region) {
   RecvTPUEmbeddingActivationsOp recv_op;
-  if (failed(GetOp(func, &recv_op))) return signalPassFailure();
+  if (failed(GetOp(region, &recv_op))) return failure();
 
   SendTPUEmbeddingGradientsOp send_op;
-  if (failed(GetOp(func, &send_op))) return signalPassFailure();
+  if (failed(GetOp(region, &send_op))) return failure();
 
   // No TPU embedding ops.
-  if (!recv_op && !send_op) return;
+  if (!recv_op && !send_op) return success();
 
   Location loc = recv_op ? recv_op.getLoc() : send_op.getLoc();
   StringRef config = recv_op ? recv_op.config() : send_op.config();
 
   // Create _RecvTPUEmbeddingDeduplicationData op.
-  OpBuilder builder(func.getBody());
-  auto output_ty = RankedTensorType::get({}, VariantType::get(&getContext()));
+  OpBuilder builder(region);
+  auto output_ty =
+      RankedTensorType::get({}, VariantType::get(region->getContext()));
   auto dedup_op = builder.create<_RecvTPUEmbeddingDeduplicationDataOp>(
       loc, output_ty, config);
 
@@ -97,6 +96,18 @@
     new_send_op->setAttr(new_send_op.getOperandSegmentSizeAttr(),
                          operand_size_attr);
   }
+  return success();
+}
+
+void RewriteTPUEmbeddingOps::runOnFunction() {
+  FuncOp func = getFunction();
+  if (failed(RunOnRegion(&func.getBody()))) return signalPassFailure();
+
+  func.walk([&](Operation* op) {
+    for (Region& region : op->getRegions()) {
+      if (failed(RunOnRegion(&region))) return signalPassFailure();
+    }
+  });
 }
 
 }  // anonymous namespace