Change RedundantCopiesRemoval to a pass
diff --git a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir
index 28c5af3..538b392 100644
--- a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir
@@ -45,6 +45,19 @@
   // CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
 }
 
+// CHECK-LABEL: func @remove_lhlo_copy_op_created_from_tensor_store
+func @remove_lhlo_copy_op_created_from_tensor_store(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: memref<f32>) {
+  %0 = "xla_hlo.max"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  tensor_store %0, %arg2 : memref<f32>
+  return
+}
+// CHECK: (%[[NEW_ARG0:.*]]: memref<f32>, %[[NEW_ARG1:.*]]: memref<f32>, %[[RESULT:.*]]: memref<f32>)
+// CHECK-NOT: alloc
+// CHECK: "xla_lhlo.max"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[RESULT]]) : (memref<f32>, memref<f32>, memref<f32>) -> ()
+// CHECK-NOT: "xla_lhlo.copy"
+// CHECK-NOT: dealloc
+// CHECK: "xla_lhlo.terminator"() : () -> ()
+
 // CHECK-LABEL: func @fusion
 func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
              %summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) {
@@ -208,4 +221,4 @@
   // CHECK-NEXT: "xla_lhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}})
   tensor_store %tensor_result, %result : memref<2x2xf32>
   return
-}
+}
\ No newline at end of file
diff --git a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc
index 2baa06a..835f87f 100644
--- a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc
@@ -27,6 +27,7 @@
 #include "mlir/IR/PatternMatch.h"  // TF:llvm-project
 #include "mlir/IR/StandardTypes.h"  // TF:llvm-project
 #include "mlir/Pass/Pass.h"  // TF:llvm-project
+#include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/DialectConversion.h"  // TF:llvm-project
 #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
 #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h"
@@ -49,17 +50,6 @@
   return nullptr;
 }
 
-Value GetTensorStore(Value value) {
-  for (const auto& user : value.getUsers()) {
-    if (auto tensor_store = dyn_cast<TensorStoreOp>(user)) {
-      if (tensor_store.getOperand(0) == value) {
-        return tensor_store.getOperand(1);
-      }
-    }
-  }
-  return nullptr;
-}
-
 Value InsertAllocAndDealloc(Location loc, Value result,
                             ConversionPatternRewriter* rewriter) {
   auto result_type = result.getType().dyn_cast<ShapedType>();
@@ -85,17 +75,6 @@
   return alloc;
 }
 
-/// For every tensor-type value that is produced in the original function,
-/// this function returns the buffer that can be used in the converted
-/// function to store that values held in the tensor.
-Value GetBufferForResultValue(Location loc, Value result,
-                              ConversionPatternRewriter* rewriter) {
-  if (auto existing_memref = GetTensorStore(result)) {
-    return existing_memref;
-  }
-  return InsertAllocAndDealloc(loc, result, rewriter);
-}
-
 template <typename HloOpTy, typename LhloOpTy>
 class HloToLhloOpConverter : public ConversionPattern {
  public:
@@ -137,7 +116,7 @@
     const auto& original_results = op.getResults();
     SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
     for (auto result : original_results) {
-      buffer_args.push_back(GetBufferForResultValue(loc, result, &rewriter));
+      buffer_args.push_back(InsertAllocAndDealloc(loc, result, &rewriter));
     }
     auto new_op = rewriter.create<xla_lhlo::ReduceOp>(
         loc, llvm::None, buffer_args, op.getAttrs());
@@ -200,38 +179,6 @@
   }
 };
 
-/// Removes Lhlo.CopyOp that copies from an allocated buffer to the block
-/// argument. All uses of the buffer are replaced with the block argument.
-void RemoveRedundantCopies(ModuleOp module) {
-  llvm::SmallVector<Operation*, 2> eraseList;
-  module.walk([&](xla_lhlo::CopyOp copyOp) {
-    auto arguments = copyOp.getOperation()->getBlock()->getArguments();
-    if (std::any_of(
-            arguments.begin(), arguments.end(),
-            [&](BlockArgument arg) { return copyOp.output() == arg; }) &&
-        std::none_of(
-            arguments.begin(), arguments.end(),
-            [&](BlockArgument arg) { return copyOp.operand() == arg; })) {
-      Value operand = copyOp.operand();
-      Value output = copyOp.output();
-      copyOp.erase();
-      for (auto op : operand.getUsers()) {
-        if (!isa<DeallocOp>(op)) {
-          op->replaceUsesOfWith(operand, output);
-        }
-      }
-      auto allocOp = operand.getDefiningOp();
-      if (auto deallocOp = dyn_cast<DeallocOp>(*allocOp->getUsers().begin())) {
-        eraseList.push_back(deallocOp);
-        eraseList.push_back(allocOp);
-      }
-    }
-  });
-  for (auto op : eraseList) {
-    op->erase();
-  }
-}
-
 // Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
 // buffers if necessary.
 //
@@ -321,8 +268,6 @@
     if (failed(applyFullConversion(module, target, patterns, nullptr))) {
       signalPassFailure();
     }
-
-    RemoveRedundantCopies(module);
   }
 };
 
@@ -442,12 +387,57 @@
   // clang-format on
 }
 
+/// Removes every xla_lhlo.copy whose output is an argument of the copy's block
+/// and whose operand is not. Non-dealloc users of the operand are redirected
+/// to that argument; a defining op left with only a dealloc is erased too.
+struct RedundantCopiesRemoval : mlir::FunctionPass<RedundantCopiesRemoval> {
+  void runOnFunction() override {
+    llvm::SmallVector<mlir::Operation*, 2> eraseList;
+    getFunction().walk([&](mlir::xla_lhlo::CopyOp copyOp) {
+      mlir::Value operand = copyOp.operand();
+      mlir::Value output = copyOp.output();
+      auto arguments = copyOp.getOperation()->getBlock()->getArguments();
+      auto isArgument = [&](mlir::Value value) {
+        return std::any_of(arguments.begin(), arguments.end(),
+                           [&](mlir::Value arg) { return arg == value; });
+      };
+      if (!isArgument(output) || isArgument(operand)) return;
+      copyOp.erase();
+      // Snapshot the users first: replaceUsesOfWith edits the use list of
+      // `operand`, so rewriting while iterating getUsers() is not safe.
+      llvm::SmallVector<mlir::Operation*, 4> users(
+          operand.getUsers().begin(), operand.getUsers().end());
+      for (mlir::Operation* user : users) {
+        if (mlir::isa<mlir::DeallocOp>(user)) continue;
+        user->replaceUsesOfWith(operand, output);
+      }
+      // `operand` may have no defining op (an argument of an enclosing block)
+      // or no remaining user; both are checked before being dereferenced.
+      mlir::Operation* allocOp = operand.getDefiningOp();
+      if (!allocOp || operand.use_empty()) return;
+      mlir::Operation* firstUser = *allocOp->getUsers().begin();
+      if (!mlir::isa<mlir::DeallocOp>(firstUser)) return;
+      eraseList.push_back(firstUser);
+      eraseList.push_back(allocOp);
+    });
+    for (auto op : eraseList) op->erase();
+  }
+};
+
 std::unique_ptr<OpPassBase<ModuleOp>> createLegalizeToLhloPass() {
   return absl::make_unique<HloLegalizeToLhlo>();
 }
 
-static PassRegistration<HloLegalizeToLhlo> legalize_pass(
-    "hlo-legalize-to-lhlo", "Legalize from HLO dialect to LHLO dialect");
+std::unique_ptr<OpPassBase<FuncOp>> createLhloCopyRemovalPass() {
+  return absl::make_unique<RedundantCopiesRemoval>();
+}
+
+static PassPipelineRegistration<> legalize_pass(
+    "hlo-legalize-to-lhlo", "Legalize from HLO dialect to LHLO dialect",
+    [](mlir::OpPassManager& pm) {
+      pm.addPass(createLegalizeToLhloPass());
+      pm.addPass(createLhloCopyRemovalPass());
+    });
 
 }  // namespace xla_hlo
 }  // namespace mlir
diff --git a/tensorflow/compiler/mlir/xla/transforms/passes.h b/tensorflow/compiler/mlir/xla/transforms/passes.h
index c890a81..14a2e33 100644
--- a/tensorflow/compiler/mlir/xla/transforms/passes.h
+++ b/tensorflow/compiler/mlir/xla/transforms/passes.h
@@ -58,6 +58,11 @@
 // Lowers from HLO dialect to Linalg dialect.
 std::unique_ptr<OpPassBase<FuncOp>> createLegalizeHloToLinalgPass();
 
+// Removes unnecessary LHLO copies which copy from the allocated buffers to the
+// block arguments. These copies have been created by replacing TensorStoreOp
+// with LHLO.CopyOp in HLO to LHLO lowering.
+std::unique_ptr<OpPassBase<FuncOp>> createLhloCopyRemovalPass();
+
 }  // namespace xla_hlo
 
 namespace xla_lhlo {
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc
index 76ad2b3..79e56e4 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc
@@ -276,6 +276,8 @@
   pm.addPass(absl::make_unique<FusionToLhloConverter>());
   // Next, we can strip the outer fusion operation.
   pm.addPass(absl::make_unique<FusionOpRemover>());
+  // Remove unnecessary Lhlo copies.
+  pm.addPass(::mlir::xla_hlo::createLhloCopyRemovalPass());
   // Transform lhlo operations to LinAlg.
   pm.addPass(::mlir::xla_lhlo::createLegalizeLhloToLinalgPass());
   // Fuse linalg operations. This will yield a single tiled loop nest where