Fix correctness corner cases in space-to-batch transformation
PiperOrigin-RevId: 358218872
Change-Id: Ic2af349f1c5eee4d5bed4c6eeb76a7c279bb2b61
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 44128b90..895a841 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -2463,6 +2463,7 @@
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
diff --git a/tensorflow/compiler/xla/service/space_to_batch_converter.cc b/tensorflow/compiler/xla/service/space_to_batch_converter.cc
index 57bdd3a..a3f6e8e 100644
--- a/tensorflow/compiler/xla/service/space_to_batch_converter.cc
+++ b/tensorflow/compiler/xla/service/space_to_batch_converter.cc
@@ -43,6 +43,7 @@
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -207,6 +208,10 @@
HloInstruction* DoesConvolutionFeedReduceWindowOrSelectAndScatter(
HloInstruction* instr, int64 depth);
+ // Checks that space-to-batching has not made the new spatial dimension
+ // smaller than the window size.
+ bool IsSpaceToBatchedSpaceSizeSuitable(HloInstruction* instr);
+
private:
// Current HloComputation instance the ConvolutionVisitor is traversing.
HloComputation* computation_;
@@ -803,7 +808,13 @@
auto win_dims =
consumer->window().dimensions(get_chosen_spatial_dim(consumer));
const int64 rhs_dilation = win_dims.window_dilation();
+ const int64 lhs_dilation = win_dims.base_dilation();
+ // LHS dilations are supported by PropagateOnConv, and not by
+ // PropagateOnBackpropFilterConv.
+ if (lhs_dilation != 1) {
+ return false;
+ }
// If the rhs_dilation is absent, we want both LHS and RHS to be space-to-
// batched for propagating on backprop convolutions.
if (!last_try || rhs_dilation == 1) {
@@ -928,6 +939,11 @@
return false;
}
}
+
+ // Make sure the post space-to-batch dim size is at least the window size.
+ if (consumer->opcode() == HloOpcode::kReduceWindow) {
+ return IsSpaceToBatchedSpaceSizeSuitable(consumer);
+ }
}
if (consumer->opcode() == HloOpcode::kSelectAndScatter) {
@@ -989,8 +1005,8 @@
"mismatch";
return false;
}
- VLOG(1) << "Can propagate through select and scatter";
- return true;
+
+ return IsSpaceToBatchedSpaceSizeSuitable(consumer);
}
return true;
}
@@ -2588,6 +2604,28 @@
return nullptr;
}
+bool ConvolutionVisitor::IsSpaceToBatchedSpaceSizeSuitable(
+ HloInstruction* instr) {
+ CHECK(instr->opcode() == HloOpcode::kSelectAndScatter ||
+ instr->opcode() == HloOpcode::kReduceWindow);
+ auto old_producer = instr->mutable_operand(0);
+
+ auto dim_map_val_op = instr_to_dim_map_[old_producer];
+ const int64 old_space_dim = dim_map_val_op.second;
+ auto first_operand = old_to_new_instrs_[old_producer];
+ auto permute_dims_first_operand = instr_to_dim_permute_map_[first_operand];
+ const int64 new_space_dim =
+ DimLookUp(permute_dims_first_operand, old_space_dim);
+
+ const int64 window_size = instr->window().dimensions(old_space_dim).size();
+
+ if (first_operand->shape().dimensions(new_space_dim) < window_size) {
+ return false;
+ }
+
+ return true;
+}
+
ConvolutionVisitor::ConvDetails ConvolutionVisitor::GetConvolutionDetails(
HloInstruction* convolution, ConvolutionDimensionNumbers& dim_numbers) {
auto activations = convolution->mutable_operand(0);
@@ -2669,6 +2707,11 @@
VLOG(1) << "spatial size " << c.spatial_size;
+ // A very primitive cost model to thwart propagations on tiny shapes.
+ if (c.spatial_size < 2 * kNumSplits) {
+ return Status::OK();
+ }
+
auto original_conv = convolution;
const int64 output_spatial_dim = dim_numbers.output_spatial_dimensions(
@@ -2695,9 +2738,14 @@
// Take into account the stride of the reduce window while choosing the
// spatial_split_size. This will guarantee propagation through reduce
// windows.
- const int64 win_stride = reduce_window_or_select_and_scatter->window()
- .dimensions(output_spatial_dim)
- .stride();
+ const int64 win_stride =
+ std::max(reduce_window_or_select_and_scatter->window()
+ .dimensions(output_spatial_dim)
+ .stride(),
+ static_cast<int64>(1));
+ CHECK_NE(win_stride, 0)
+ << "Bad op " << reduce_window_or_select_and_scatter->ToString();
+ CHECK_NE(c.stride, 0) << "Bad op " << convolution->ToString();
while ((spatial_split_size / c.stride) % win_stride != 0) {
spatial_split_size += c.stride;
}
diff --git a/tensorflow/compiler/xla/service/space_to_batch_converter_test.cc b/tensorflow/compiler/xla/service/space_to_batch_converter_test.cc
index 8921d98..96cfe55 100644
--- a/tensorflow/compiler/xla/service/space_to_batch_converter_test.cc
+++ b/tensorflow/compiler/xla/service/space_to_batch_converter_test.cc
@@ -118,9 +118,9 @@
HloModule module
ENTRY computation {
- %p2 = bf16[1,7,7,128]{3,0,2,1} parameter(0)
+ %p2 = bf16[1,28,28,128]{3,0,2,1} parameter(0)
%p3 = bf16[1,1,512,128]{3,2,1,0} parameter(1)
- ROOT %c = bf16[1,14,14,512]{3,0,2,1} convolution(%p2, %p3),
+ ROOT %c = bf16[1,56,56,512]{3,0,2,1} convolution(%p2, %p3),
window={size=1x1 pad=0_1x0_1 lhs_dilate=2x2 rhs_reversal=1x1},
dim_labels=b01f_01oi->b01f
}