Support hlo to lhlo buffer placement through shape.assuming ops.

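With this change, the hlo-legalize-to-lhlo pass marks shape.assuming ops
as legal only once all of their result types are memrefs, and registers
the shape dialect's type conversion patterns
(populateShapeTypeConversionPatterns) so bufferization can rewrite the
op's result and yield types in place.

A minimal sketch of the intended rewrite, distilled from the test added
below (%witness, %lhs, %rhs, and %out are illustrative names; the buffer
allocation for %out is omitted):

    // Before: the region's results and yields are tensor-typed.
    %res = shape.assuming %witness -> (tensor<?xf16>) {
      %max = mhlo.maximum %lhs, %rhs : tensor<?xf16>
      shape.assuming_yield %max : tensor<?xf16>
    }

    // After: the region yields buffers, and mhlo.maximum has been
    // rewritten to lmhlo.maximum writing into an output buffer.
    %res = shape.assuming %witness -> (memref<?xf16>) {
      "lmhlo.maximum"(%lhs, %rhs, %out) : (memref<?xf16>, memref<?xf16>, memref<?xf16>) -> ()
      shape.assuming_yield %out : memref<?xf16>
    }
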
PiperOrigin-RevId: 336287728
Change-Id: I8975254382caff1091cd8e9aef8730d506433278
diff --git a/tensorflow/compiler/mlir/hlo/BUILD b/tensorflow/compiler/mlir/hlo/BUILD
index bc62ccc..30e8d8a 100644
--- a/tensorflow/compiler/mlir/hlo/BUILD
+++ b/tensorflow/compiler/mlir/hlo/BUILD
@@ -540,6 +540,8 @@
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Shape",
+        "@llvm-project//mlir:ShapeTransforms",
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:Transforms",
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
index 3485aff..22338d2 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
@@ -20,6 +20,8 @@
 #include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Shape/Transforms/Passes.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Attributes.h"
@@ -448,6 +450,10 @@
       return std::all_of(op.operand_type_begin(), op.operand_type_end(),
                          isMemRefType);
     });
+    target.addDynamicallyLegalOp<shape::AssumingOp>([&](shape::AssumingOp op) {
+      return std::all_of(op.result_type_begin(), op.result_type_end(),
+                         isMemRefType);
+    });
 
     auto kind = results_escape_function
                     ? BufferAssignmentTypeConverter::KeepAsFunctionResult
@@ -460,6 +466,7 @@
     populateWithBufferAssignmentOpConversionPatterns<
         mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp>(&context, &converter,
                                                        &patterns);
+    populateShapeTypeConversionPatterns(&context, &converter, &patterns);
     if (failed(applyPartialConversion(getOperation(), target, patterns)))
       signalPassFailure();
   }
diff --git a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir
index 75e5c1b..3caa4f0 100644
--- a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir
@@ -612,3 +612,24 @@
   tensor_store %result_tensor, %result: memref<2x2xi1>
   return
 }
+
+// -----
+
+// Test that assuming ops propagate memref types.
+// BOTH-LABEL: func @shape_assuming_memref
+func @shape_assuming_memref(%arg0: tensor<?xf16>) -> tensor<?xf16> {
+  %0 = mhlo.constant dense<0.000000e+00> : tensor<f16>
+  %1 = shape.const_witness true
+  // BOTH: shape.assuming %{{.*}} -> (memref<?xf16>)
+  %2 = shape.assuming %1 -> (tensor<?xf16>) {
+    %3 = shape.shape_of %arg0 : tensor<?xf16> -> tensor<?xindex>
+    %4 = tensor_cast %3 : tensor<?xindex> to tensor<1xindex>
+    %5 = "mhlo.dynamic_broadcast_in_dim"(%0, %4) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f16>, tensor<1xindex>) -> tensor<?xf16>
+    %6 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %4) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf16>, tensor<1xindex>) -> tensor<?xf16>
+    // BOTH: "lmhlo.maximum"(%6, %9, %20) : (memref<?xf16>, memref<?xf16>, memref<?xf16>) -> ()
+    %7 = mhlo.maximum %5, %6 : tensor<?xf16>
+    // BOTH: shape.assuming_yield %{{.*}} : memref<?xf16>
+    shape.assuming_yield %7 : tensor<?xf16>
+  }
+  return %2 : tensor<?xf16>
+}