Remove MhloBufferizationState, it is not needed anymore.

Bufferization now has logic in the bufferization of ops to create a copy if the op
cannot deal with non-identity maps.

PiperOrigin-RevId: 441733148
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/bufferizable_op_interface_impl.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/bufferizable_op_interface_impl.h
index 1f1220e..4fc8a0b 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/bufferizable_op_interface_impl.h
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/bufferizable_op_interface_impl.h
@@ -24,18 +24,6 @@
 namespace mlir {
 namespace mhlo {
 
-/// mhlo dialect analysis state. mhlo-specific bufferization options are
-/// stored in this state.
-struct MhloBufferizationState : public bufferization::DialectAnalysisState {
-  using EnforceIdentityMapFn = std::function<bool(Operation *)>;
-
-  /// If this function returns true for an op, copies will be inserted when
-  /// the lowering would otherwise lead to a memref with a non-identity map.
-  EnforceIdentityMapFn enforce_identity_map_fn = [](Operation *) {
-    return true;
-  };
-};
-
 /// Register the external models for bufferizing mhlo ops.
 void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
 
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_memref.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_memref.cc
index c8eb426..5c1cf03 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_memref.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_memref.cc
@@ -258,30 +258,6 @@
   return transformed_operand;
 }
 
-Value CreateCopy(mhlo::DynamicBroadcastInDimOp op, Value broadcasted,
-                 OpBuilder *b) {
-  MemRefType result_type = broadcasted.getType().cast<MemRefType>();
-  auto loc = op.getLoc();
-  SmallVector<Value, 4> dynamic_operands;
-  for (int i = 0; i < result_type.getRank(); ++i) {
-    if (!result_type.isDynamicDim(i)) continue;
-    auto index = b->createOrFold<arith::ConstantIndexOp>(loc, i);
-    Value size =
-        b->create<tensor::ExtractOp>(loc, op.output_dimensions(), index);
-    if (!size.getType().isIndex()) {
-      size = b->create<arith::IndexCastOp>(loc, b->getIndexType(), size);
-    }
-    dynamic_operands.push_back(size);
-  }
-  auto identity_map_memref =
-      MemRefType::get(result_type.getShape(), result_type.getElementType());
-  auto copy = b->create<memref::AllocOp>(op.getLoc(), identity_map_memref,
-                                         dynamic_operands);
-  b->create<memref::CopyOp>(loc, broadcasted, copy);
-
-  return copy;
-}
-
 struct DynamicBroadcastInDimOpInterface
     : public BufferizableOpInterface::ExternalModel<
           DynamicBroadcastInDimOpInterface, mhlo::DynamicBroadcastInDimOp> {
@@ -322,15 +298,6 @@
     Value result = InsertDynamicMemrefCastOp(broadcast_in_dim_op,
                                              *operand_buffer, &rewriter);
 
-    // Evaluate `enforce_identity_map_fn` and maybe create a copy.
-    Optional<const MhloBufferizationState *> dialect_state =
-        state.getAnalysisState().getDialectState<MhloBufferizationState>(
-            mhlo::MhloDialect::getDialectNamespace());
-    assert(dialect_state.hasValue() && "mhlo dialect state not initialized");
-    if ((*dialect_state)->enforce_identity_map_fn(op)) {
-      result = CreateCopy(broadcast_in_dim_op, result, &rewriter);
-    }
-
     bufferization::replaceOpWithBufferizedValues(rewriter, op, result);
     return success();
   }
@@ -349,10 +316,6 @@
     bufferization::BufferizationOptions options =
         bufferization::getPartialBufferizationOptions();
     options.allowDialectInFilter<mhlo::MhloDialect>();
-    // mhlo dialect state must be explicitly initialized to ease debugging.
-    options.addDialectStateInitializer(
-        mhlo::MhloDialect::getDialectNamespace(),
-        []() { return std::make_unique<MhloBufferizationState>(); });
     if (failed(bufferizeOp(getOperation(), options))) signalPassFailure();
   }
 };
diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-memref.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-memref.mlir
index fbf4196..987a6e4 100644
--- a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-memref.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-memref.mlir
@@ -25,10 +25,8 @@
 // CHECK: %[[STRIDE_2:.*]] = arith.select %[[EXPAND_2]], %[[C0]], %[[C1]] : index
 
 // CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref.reinterpret_cast %[[OPERAND]] to offset: [0], sizes: [%[[C1]], %[[C1]], %[[C1]]], strides: [%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]] : memref<?x?xf32> to memref<?x?x?xf32, #map>
-// CHECK: %[[ALLOC:.*]] = memref.alloc
-// CHECK: memref.copy %[[TRANSFORMED_MEMREF]], %[[ALLOC]] : memref<?x?x?xf32, #map> to memref<?x?x?xf32>
 
-// CHECK: %[[RESULT:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK: %[[RESULT:.*]] = bufferization.to_tensor %[[TRANSFORMED_MEMREF]]
 
 // CHECK: return %[[RESULT]]
 
@@ -67,10 +65,8 @@
 // CHECK: %[[STRIDE_2:.*]] = arith.select %[[EXPAND_2]], %[[C0]], %[[C1]] : index
 
 // CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref.reinterpret_cast %[[OPERAND]] to offset: [0], sizes: [%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: [%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]] : memref<?x?xi32> to memref<?x?x?xi32, #map>
-// CHECK: %[[ALLOC:.*]] = memref.alloc
-// CHECK: memref.copy %[[TRANSFORMED_MEMREF]], %[[ALLOC]] : memref<?x?x?xi32, #map> to memref<?x?x?xi32>
 
-// CHECK: %[[RESULT:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK: %[[RESULT:.*]] = bufferization.to_tensor %[[TRANSFORMED_MEMREF]]
 
 // CHECK: return %[[RESULT]]
 
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc
index f531602..06dcc8c 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc
@@ -150,21 +150,6 @@
     options.denyOperationInFilter([](Operation* op) {
       return mlir::isa<gml_st::LoopOp>(op->getParentOp());
     });
-    // Configure bufferization options for mhlo ops.
-    options.addDialectStateInitializer(
-        mhlo::MhloDialect::getDialectNamespace(), []() {
-          auto dialect_state = std::make_unique<mhlo::MhloBufferizationState>();
-          dialect_state->enforce_identity_map_fn = [](Operation* op) {
-            // Force identity maps for several ops which don't support memrefs
-            // with affine_maps.
-            return llvm::any_of(op->getUsers(), [](Operation* user) {
-              return isa<gml_st::LoopOp, func::ReturnOp, mhlo::DynamicReshapeOp,
-                         tensor::CastOp, tensor::CollapseShapeOp,
-                         tensor::ExpandShapeOp>(user);
-            });
-          };
-          return dialect_state;
-        });
 
     if (failed(bufferization::bufferizeOp(getOperation(), options))) {
       signalPassFailure();