Added test cases for new BufferAssignment analysis.
diff --git a/tensorflow/compiler/mlir/xla/tests/buffer-assignment.mlir b/tensorflow/compiler/mlir/xla/tests/buffer-assignment.mlir
new file mode 100644
index 0000000..8cf8592
--- /dev/null
+++ b/tensorflow/compiler/mlir/xla/tests/buffer-assignment.mlir
@@ -0,0 +1,131 @@
+// RUN: tf-opt -test-buffer-assignment -split-input-file %s -o - 2>&1 | FileCheck %s -dump-input-on-failure
+
+// CHECK-LABEL: Testing : condBranch
+func @condBranch(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{
+  // CHECK: Alloc: cond_br
+    cond_br %cond, ^bb1, ^bb2
+  ^bb1:
+    br ^exit(%arg0 : tensor<2xf32>)
+  ^bb2:
+    %1 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+    br ^exit(%1 : tensor<2xf32>)
+  ^exit(%arg1: tensor<2xf32>):
+    return %arg1 : tensor<2xf32>
+  // CHECK-NEXT: Dealloc: return
+}
+
+// -----
+
+// CHECK-LABEL: Testing : criticalEdge
+func @criticalEdge(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{
+  // CHECK: Alloc: cond_br
+    cond_br %cond, ^bb1, ^exit(%arg0 : tensor<2xf32>)
+  ^bb1:
+    %0 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+    br ^exit(%0 : tensor<2xf32>)
+  ^exit(%arg1: tensor<2xf32>):
+    return %arg1 : tensor<2xf32>
+  // CHECK-NEXT: Dealloc: return
+}
+
+// -----
+
+// CHECK-LABEL: Testing : invCriticalEdge
+func @invCriticalEdge(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{
+  // CHECK: Alloc: %0 = "xla_hlo.exp"
+    %0 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+    cond_br %cond, ^bb1, ^exit(%arg0 : tensor<2xf32>)
+  ^bb1:
+    br ^exit(%0 : tensor<2xf32>)
+  ^exit(%arg1: tensor<2xf32>):
+    return %arg1 : tensor<2xf32>
+  // CHECK-NEXT: Dealloc: return
+}
+
+// -----
+
+// CHECK-LABEL: Testing : ifElse
+func @ifElse(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{
+  // CHECK: Alloc: %0 = "xla_hlo.exp"(%arg1)
+    %0 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+    cond_br %cond, ^bb1(%arg0, %0: tensor<2xf32>, tensor<2xf32>), ^bb2(%0, %arg0: tensor<2xf32>, tensor<2xf32>)
+  ^bb1(%arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>):
+    br ^exit(%arg1, %arg2 : tensor<2xf32>, tensor<2xf32>)
+  ^bb2(%arg3 : tensor<2xf32>, %arg4 : tensor<2xf32>):
+    br ^exit(%arg3, %arg4 : tensor<2xf32>, tensor<2xf32>)
+  ^exit(%arg5 : tensor<2xf32>, %arg6 : tensor<2xf32>):
+  // CHECK-NEXT: Dealloc: %7 = "xla_hlo.exp"(%5)
+  // CHECK: Alloc: %7 = "xla_hlo.exp"(%5)
+  // CHECK-NEXT: Dealloc: return
+    %1 = "xla_hlo.exp"(%arg5) : (tensor<2xf32>) -> tensor<2xf32>
+    return %1 : tensor<2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: Testing : ifElseNoUsers
+func @ifElseNoUsers(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{
+  // CHECK: Alloc: %0 = "xla_hlo.exp"(%arg1)
+    %0 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+    cond_br %cond, ^bb1(%arg0, %0: tensor<2xf32>, tensor<2xf32>), ^bb2(%0, %arg0: tensor<2xf32>, tensor<2xf32>)
+  ^bb1(%arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>):
+    br ^exit(%arg1, %arg2 : tensor<2xf32>, tensor<2xf32>)
+  ^bb2(%arg3 : tensor<2xf32>, %arg4 : tensor<2xf32>):
+    br ^exit(%arg3, %arg4 : tensor<2xf32>, tensor<2xf32>)
+  ^exit(%arg5 : tensor<2xf32>, %arg6 : tensor<2xf32>):
+  // CHECK-NEXT: return
+    return %arg0 : tensor<2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: Testing : ifElseNested
+func @ifElseNested(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{
+  // CHECK: Alloc: %0 = "xla_hlo.exp"(%arg1)
+    %0 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+    cond_br %cond, ^bb1(%arg0, %0: tensor<2xf32>, tensor<2xf32>), ^bb2(%0, %arg0: tensor<2xf32>, tensor<2xf32>)
+  ^bb1(%arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>):
+    br ^exit(%arg1, %arg2 : tensor<2xf32>, tensor<2xf32>)
+  ^bb2(%arg3 : tensor<2xf32>, %arg4 : tensor<2xf32>):
+    cond_br %cond, ^bb3(%arg3 : tensor<2xf32>), ^bb4(%arg4 : tensor<2xf32>)
+  ^bb3(%arg7 : tensor<2xf32>):
+    br ^exit(%arg7, %arg3 : tensor<2xf32>, tensor<2xf32>)
+  ^bb4(%arg8 : tensor<2xf32>):
+    br ^exit(%arg3, %arg8 : tensor<2xf32>, tensor<2xf32>)
+  ^exit(%arg5 : tensor<2xf32>, %arg6 : tensor<2xf32>):
+  // CHECK-NEXT: Dealloc: %9 = "xla_hlo.exp"(%7)
+  // CHECK: Alloc: %9 = "xla_hlo.exp"(%7)
+  // CHECK-NEXT: Dealloc: return
+    %1 = "xla_hlo.exp"(%arg5) : (tensor<2xf32>) -> tensor<2xf32>
+    return %1 : tensor<2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: Testing : redundantOperations
+func @redundantOperations(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) {
+  // CHECK: Alloc: %0 = xla_hlo.max
+  // CHECK-NEXT: Dealloc: %1 = xla_hlo.add
+  %1 = "xla_hlo.max"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  // CHECK: Alloc: %1 = xla_hlo.add
+  // CHECK-NEXT: Dealloc: %1 = xla_hlo.add
+  %2 = "xla_hlo.add"(%arg0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: Testing : reduce
+func @reduce(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // CHECK: Alloc: %0 = xla_hlo.constant
+  // CHECK-NEXT: Dealloc: %1 = "xla_hlo.reduce"(%arg0, %0)
+  %0 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
+  // CHECK: Alloc: %1 = "xla_hlo.reduce"(%arg0, %0)
+  // CHECK: Dealloc: return
+  %2 = "xla_hlo.reduce"(%arg0, %0) ( {
+  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+    %4 = xla_hlo.add %arg1, %arg2 : tensor<f32>
+    "xla_hlo.return"(%4) : (tensor<f32>) -> ()
+  }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4x8xf32>
+  return %2 : tensor<4x8xf32>
+}
diff --git a/tensorflow/compiler/mlir/xla/transforms/buffer_assignment.cc b/tensorflow/compiler/mlir/xla/transforms/buffer_assignment.cc
index 81df3e6..bd9d02f 100644
--- a/tensorflow/compiler/mlir/xla/transforms/buffer_assignment.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/buffer_assignment.cc
@@ -260,5 +260,22 @@
         }
 }
 
+/// A simple pass to print debug/test information for the buffer assignment
+/// analysis.
+struct BufferAssignmentTestPass : mlir::FunctionPass<BufferAssignmentTestPass> {
+  void runOnFunction() override {
+    llvm::errs() << "Testing : " << getFunction().getName() << "\n";
+    getAnalysis<BufferAssignment>().print(llvm::errs());
+  };
+};
+
+std::unique_ptr<OpPassBase<FuncOp>> createBufferAssignmentTestPass() {
+  return absl::make_unique<BufferAssignmentTestPass>();
+}
+
+static PassRegistration<BufferAssignmentTestPass> buffer_assignment_test_pass(
+    "test-buffer-assignment",
+    "Outputs debug test information for the buffer assignment analysis");
+
 }  // namespace xla
 }  // namespace mlir