Support hlo to lhlo buffer placement through shape.assuming ops.
PiperOrigin-RevId: 336287728
Change-Id: I8975254382caff1091cd8e9aef8730d506433278
diff --git a/tensorflow/compiler/mlir/hlo/BUILD b/tensorflow/compiler/mlir/hlo/BUILD
index bc62ccc..30e8d8a 100644
--- a/tensorflow/compiler/mlir/hlo/BUILD
+++ b/tensorflow/compiler/mlir/hlo/BUILD
@@ -540,6 +540,8 @@
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:Shape",
+ "@llvm-project//mlir:ShapeTransforms",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
index 3485aff..22338d2 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
@@ -20,6 +20,8 @@
#include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
@@ -448,6 +450,10 @@
return std::all_of(op.operand_type_begin(), op.operand_type_end(),
isMemRefType);
});
+ target.addDynamicallyLegalOp<shape::AssumingOp>([&](shape::AssumingOp op) {
+ return std::all_of(op.result_type_begin(), op.result_type_end(),
+ isMemRefType);
+ });
auto kind = results_escape_function
? BufferAssignmentTypeConverter::KeepAsFunctionResult
@@ -460,6 +466,7 @@
populateWithBufferAssignmentOpConversionPatterns<
mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp>(&context, &converter,
&patterns);
+ populateShapeTypeConversionPatterns(&context, &converter, &patterns);
if (failed(applyPartialConversion(getOperation(), target, patterns)))
signalPassFailure();
}
diff --git a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir
index 75e5c1b..3caa4f0 100644
--- a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir
@@ -612,3 +612,24 @@
tensor_store %result_tensor, %result: memref<2x2xi1>
return
}
+
+// -----
+
+// Test that assuming ops propagate memref types.
+// BOTH-LABEL: func @shape_assuming_memref
+func @shape_assuming_memref(%arg0: tensor<?xf16>) -> tensor<?xf16> {
+ %0 = mhlo.constant dense<0.000000e+00> : tensor<f16>
+ %1 = shape.const_witness true
+ // BOTH: shape.assuming %{{.*}} -> (memref<?xf16>)
+ %2 = shape.assuming %1 -> (tensor<?xf16>) {
+ %3 = shape.shape_of %arg0 : tensor<?xf16> -> tensor<?xindex>
+ %4 = tensor_cast %3 : tensor<?xindex> to tensor<1xindex>
+ %5 = "mhlo.dynamic_broadcast_in_dim"(%0, %4) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f16>, tensor<1xindex>) -> tensor<?xf16>
+ %6 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %4) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf16>, tensor<1xindex>) -> tensor<?xf16>
+ // BOTH: "lmhlo.maximum"(%6, %9, %20) : (memref<?xf16>, memref<?xf16>, memref<?xf16>) -> ()
+ %7 = mhlo.maximum %5, %6 : tensor<?xf16>
+ // BOTH: shape.assuming_yield %{{.*}} : memref<?xf16>
+ shape.assuming_yield %7 : tensor<?xf16>
+ }
+ return %2 : tensor<?xf16>
+}
\ No newline at end of file