Remove MhloBufferizationState, it is not needed anymore.
Bufferization now has logic in the bufferization of ops to create a copy if the op
cannot deal with non-identity maps.
PiperOrigin-RevId: 441733148
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/bufferizable_op_interface_impl.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/bufferizable_op_interface_impl.h
index 1f1220e..4fc8a0b 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/bufferizable_op_interface_impl.h
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/bufferizable_op_interface_impl.h
@@ -24,18 +24,6 @@
namespace mlir {
namespace mhlo {
-/// mhlo dialect analysis state. mhlo-specific bufferization options are
-/// stored in this state.
-struct MhloBufferizationState : public bufferization::DialectAnalysisState {
- using EnforceIdentityMapFn = std::function<bool(Operation *)>;
-
- /// If this function returns true for an op, copies will be inserted when
- /// the lowering would otherwise lead to a memref with a non-identity map.
- EnforceIdentityMapFn enforce_identity_map_fn = [](Operation *) {
- return true;
- };
-};
-
/// Register the external models for bufferizing mhlo ops.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry);
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_memref.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_memref.cc
index c8eb426..5c1cf03 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_memref.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_memref.cc
@@ -258,30 +258,6 @@
return transformed_operand;
}
-Value CreateCopy(mhlo::DynamicBroadcastInDimOp op, Value broadcasted,
- OpBuilder *b) {
- MemRefType result_type = broadcasted.getType().cast<MemRefType>();
- auto loc = op.getLoc();
- SmallVector<Value, 4> dynamic_operands;
- for (int i = 0; i < result_type.getRank(); ++i) {
- if (!result_type.isDynamicDim(i)) continue;
- auto index = b->createOrFold<arith::ConstantIndexOp>(loc, i);
- Value size =
- b->create<tensor::ExtractOp>(loc, op.output_dimensions(), index);
- if (!size.getType().isIndex()) {
- size = b->create<arith::IndexCastOp>(loc, b->getIndexType(), size);
- }
- dynamic_operands.push_back(size);
- }
- auto identity_map_memref =
- MemRefType::get(result_type.getShape(), result_type.getElementType());
- auto copy = b->create<memref::AllocOp>(op.getLoc(), identity_map_memref,
- dynamic_operands);
- b->create<memref::CopyOp>(loc, broadcasted, copy);
-
- return copy;
-}
-
struct DynamicBroadcastInDimOpInterface
: public BufferizableOpInterface::ExternalModel<
DynamicBroadcastInDimOpInterface, mhlo::DynamicBroadcastInDimOp> {
@@ -322,15 +298,6 @@
Value result = InsertDynamicMemrefCastOp(broadcast_in_dim_op,
*operand_buffer, &rewriter);
- // Evaluate `enforce_identity_map_fn` and maybe create a copy.
- Optional<const MhloBufferizationState *> dialect_state =
- state.getAnalysisState().getDialectState<MhloBufferizationState>(
- mhlo::MhloDialect::getDialectNamespace());
- assert(dialect_state.hasValue() && "mhlo dialect state not initialized");
- if ((*dialect_state)->enforce_identity_map_fn(op)) {
- result = CreateCopy(broadcast_in_dim_op, result, &rewriter);
- }
-
bufferization::replaceOpWithBufferizedValues(rewriter, op, result);
return success();
}
@@ -349,10 +316,6 @@
bufferization::BufferizationOptions options =
bufferization::getPartialBufferizationOptions();
options.allowDialectInFilter<mhlo::MhloDialect>();
- // mhlo dialect state must be explicitly initialized to ease debugging.
- options.addDialectStateInitializer(
- mhlo::MhloDialect::getDialectNamespace(),
- []() { return std::make_unique<MhloBufferizationState>(); });
if (failed(bufferizeOp(getOperation(), options))) signalPassFailure();
}
};
diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-memref.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-memref.mlir
index fbf4196..987a6e4 100644
--- a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-memref.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-memref.mlir
@@ -25,10 +25,8 @@
// CHECK: %[[STRIDE_2:.*]] = arith.select %[[EXPAND_2]], %[[C0]], %[[C1]] : index
// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref.reinterpret_cast %[[OPERAND]] to offset: [0], sizes: [%[[C1]], %[[C1]], %[[C1]]], strides: [%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]] : memref<?x?xf32> to memref<?x?x?xf32, #map>
-// CHECK: %[[ALLOC:.*]] = memref.alloc
-// CHECK: memref.copy %[[TRANSFORMED_MEMREF]], %[[ALLOC]] : memref<?x?x?xf32, #map> to memref<?x?x?xf32>
-// CHECK: %[[RESULT:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK: %[[RESULT:.*]] = bufferization.to_tensor %[[TRANSFORMED_MEMREF]]
// CHECK: return %[[RESULT]]
@@ -67,10 +65,8 @@
// CHECK: %[[STRIDE_2:.*]] = arith.select %[[EXPAND_2]], %[[C0]], %[[C1]] : index
// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref.reinterpret_cast %[[OPERAND]] to offset: [0], sizes: [%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: [%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]] : memref<?x?xi32> to memref<?x?x?xi32, #map>
-// CHECK: %[[ALLOC:.*]] = memref.alloc
-// CHECK: memref.copy %[[TRANSFORMED_MEMREF]], %[[ALLOC]] : memref<?x?x?xi32, #map> to memref<?x?x?xi32>
-// CHECK: %[[RESULT:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK: %[[RESULT:.*]] = bufferization.to_tensor %[[TRANSFORMED_MEMREF]]
// CHECK: return %[[RESULT]]
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc
index f531602..06dcc8c 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc
@@ -150,21 +150,6 @@
options.denyOperationInFilter([](Operation* op) {
return mlir::isa<gml_st::LoopOp>(op->getParentOp());
});
- // Configure bufferization options for mhlo ops.
- options.addDialectStateInitializer(
- mhlo::MhloDialect::getDialectNamespace(), []() {
- auto dialect_state = std::make_unique<mhlo::MhloBufferizationState>();
- dialect_state->enforce_identity_map_fn = [](Operation* op) {
- // Force identity maps for several ops which don't support memrefs
- // with affine_maps.
- return llvm::any_of(op->getUsers(), [](Operation* user) {
- return isa<gml_st::LoopOp, func::ReturnOp, mhlo::DynamicReshapeOp,
- tensor::CastOp, tensor::CollapseShapeOp,
- tensor::ExpandShapeOp>(user);
- });
- };
- return dialect_state;
- });
if (failed(bufferization::bufferizeOp(getOperation(), options))) {
signalPassFailure();