Added test cases for new BufferAssignment analysis.
diff --git a/tensorflow/compiler/mlir/xla/tests/buffer-assignment.mlir b/tensorflow/compiler/mlir/xla/tests/buffer-assignment.mlir
new file mode 100644
index 0000000..8cf8592
--- /dev/null
+++ b/tensorflow/compiler/mlir/xla/tests/buffer-assignment.mlir
@@ -0,0 +1,131 @@
+// RUN: tf-opt -test-buffer-assignment -split-input-file %s -o - 2>&1 | FileCheck %s -dump-input-on-failure
+
+// CHECK-LABEL: Testing : condBranch
+func @condBranch(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{
+ // CHECK: Alloc: cond_br
+ cond_br %cond, ^bb1, ^bb2
+ ^bb1:
+ br ^exit(%arg0 : tensor<2xf32>)
+ ^bb2:
+ %1 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ br ^exit(%1 : tensor<2xf32>)
+ ^exit(%arg1: tensor<2xf32>):
+ return %arg1 : tensor<2xf32>
+ // CHECK-NEXT: Dealloc: return
+}
+
+// -----
+
+// CHECK-LABEL: Testing : criticalEdge
+func @criticalEdge(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{
+ // CHECK: Alloc: cond_br
+ cond_br %cond, ^bb1, ^exit(%arg0 : tensor<2xf32>)
+ ^bb1:
+ %0 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ br ^exit(%0 : tensor<2xf32>)
+ ^exit(%arg1: tensor<2xf32>):
+ return %arg1 : tensor<2xf32>
+ // CHECK-NEXT: Dealloc: return
+}
+
+// -----
+
+// CHECK-LABEL: Testing : invCriticalEdge
+func @invCriticalEdge(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{
+ // CHECK: Alloc: %0 = "xla_hlo.exp"
+ %0 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ cond_br %cond, ^bb1, ^exit(%arg0 : tensor<2xf32>)
+ ^bb1:
+ br ^exit(%0 : tensor<2xf32>)
+ ^exit(%arg1: tensor<2xf32>):
+ return %arg1 : tensor<2xf32>
+ // CHECK-NEXT: Dealloc: return
+}
+
+// -----
+
+// CHECK-LABEL: Testing : ifElse
+func @ifElse(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{
+ // CHECK: Alloc: %0 = "xla_hlo.exp"(%arg1)
+ %0 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ cond_br %cond, ^bb1(%arg0, %0: tensor<2xf32>, tensor<2xf32>), ^bb2(%0, %arg0: tensor<2xf32>, tensor<2xf32>)
+ ^bb1(%arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>):
+ br ^exit(%arg1, %arg2 : tensor<2xf32>, tensor<2xf32>)
+ ^bb2(%arg3 : tensor<2xf32>, %arg4 : tensor<2xf32>):
+ br ^exit(%arg3, %arg4 : tensor<2xf32>, tensor<2xf32>)
+ ^exit(%arg5 : tensor<2xf32>, %arg6 : tensor<2xf32>):
+ // CHECK-NEXT: Dealloc: %7 = "xla_hlo.exp"(%5)
+ // CHECK: Alloc: %7 = "xla_hlo.exp"(%5)
+ // CHECK-NEXT: Dealloc: return
+ %1 = "xla_hlo.exp"(%arg5) : (tensor<2xf32>) -> tensor<2xf32>
+ return %1 : tensor<2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: Testing : ifElseNoUsers
+func @ifElseNoUsers(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{
+ // CHECK: Alloc: %0 = "xla_hlo.exp"(%arg1)
+ %0 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ cond_br %cond, ^bb1(%arg0, %0: tensor<2xf32>, tensor<2xf32>), ^bb2(%0, %arg0: tensor<2xf32>, tensor<2xf32>)
+ ^bb1(%arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>):
+ br ^exit(%arg1, %arg2 : tensor<2xf32>, tensor<2xf32>)
+ ^bb2(%arg3 : tensor<2xf32>, %arg4 : tensor<2xf32>):
+ br ^exit(%arg3, %arg4 : tensor<2xf32>, tensor<2xf32>)
+ ^exit(%arg5 : tensor<2xf32>, %arg6 : tensor<2xf32>):
+ // CHECK-NEXT: return
+ return %arg0 : tensor<2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: Testing : ifElseNested
+func @ifElseNested(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{
+ // CHECK: Alloc: %0 = "xla_hlo.exp"(%arg1)
+ %0 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ cond_br %cond, ^bb1(%arg0, %0: tensor<2xf32>, tensor<2xf32>), ^bb2(%0, %arg0: tensor<2xf32>, tensor<2xf32>)
+ ^bb1(%arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>):
+ br ^exit(%arg1, %arg2 : tensor<2xf32>, tensor<2xf32>)
+ ^bb2(%arg3 : tensor<2xf32>, %arg4 : tensor<2xf32>):
+ cond_br %cond, ^bb3(%arg3 : tensor<2xf32>), ^bb4(%arg4 : tensor<2xf32>)
+ ^bb3(%arg7 : tensor<2xf32>):
+ br ^exit(%arg7, %arg3 : tensor<2xf32>, tensor<2xf32>)
+ ^bb4(%arg8 : tensor<2xf32>):
+ br ^exit(%arg3, %arg8 : tensor<2xf32>, tensor<2xf32>)
+ ^exit(%arg5 : tensor<2xf32>, %arg6 : tensor<2xf32>):
+ // CHECK-NEXT: Dealloc: %9 = "xla_hlo.exp"(%7)
+ // CHECK: Alloc: %9 = "xla_hlo.exp"(%7)
+ // CHECK-NEXT: Dealloc: return
+ %1 = "xla_hlo.exp"(%arg5) : (tensor<2xf32>) -> tensor<2xf32>
+ return %1 : tensor<2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: Testing : redundantOperations
+func @redundantOperations(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) {
+ // CHECK: Alloc: %0 = xla_hlo.max
+ // CHECK-NEXT: Dealloc: %1 = xla_hlo.add
+ %1 = "xla_hlo.max"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+ // CHECK: Alloc: %1 = xla_hlo.add
+ // CHECK-NEXT: Dealloc: %1 = xla_hlo.add
+ %2 = "xla_hlo.add"(%arg0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: Testing : reduce
+func @reduce(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
+ // CHECK: Alloc: %0 = xla_hlo.constant
+ // CHECK-NEXT: Dealloc: %1 = "xla_hlo.reduce"(%arg0, %0)
+ %0 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
+ // CHECK: Alloc: %1 = "xla_hlo.reduce"(%arg0, %0)
+ // CHECK: Dealloc: return
+ %2 = "xla_hlo.reduce"(%arg0, %0) ( {
+ ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+ %4 = xla_hlo.add %arg1, %arg2 : tensor<f32>
+ "xla_hlo.return"(%4) : (tensor<f32>) -> ()
+ }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4x8xf32>
+ return %2 : tensor<4x8xf32>
+}
diff --git a/tensorflow/compiler/mlir/xla/transforms/buffer_assignment.cc b/tensorflow/compiler/mlir/xla/transforms/buffer_assignment.cc
index 81df3e6..bd9d02f 100644
--- a/tensorflow/compiler/mlir/xla/transforms/buffer_assignment.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/buffer_assignment.cc
@@ -260,5 +260,22 @@
}
}
+/// A simple pass to print debug/test information for the buffer assignment
+/// analysis.
+struct BufferAssignmentTestPass : mlir::FunctionPass<BufferAssignmentTestPass> {
+ void runOnFunction() override {
+ llvm::errs() << "Testing : " << getFunction().getName() << "\n";
+ getAnalysis<BufferAssignment>().print(llvm::errs());
+ };
+};
+
+std::unique_ptr<OpPassBase<FuncOp>> createBufferAssignmentTestPass() {
+ return absl::make_unique<BufferAssignmentTestPass>();
+}
+
+static PassRegistration<BufferAssignmentTestPass> buffer_assignment_test_pass(
+ "test-buffer-assignment",
+ "Outputs debug test information for the buffer assignment analysis");
+
} // namespace xla
} // namespace mlir