Fix correctness corner cases in space-to-batch transformation

PiperOrigin-RevId: 358218872
Change-Id: Ic2af349f1c5eee4d5bed4c6eeb76a7c279bb2b61
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 44128b90..895a841 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -2463,6 +2463,7 @@
         "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
diff --git a/tensorflow/compiler/xla/service/space_to_batch_converter.cc b/tensorflow/compiler/xla/service/space_to_batch_converter.cc
index 57bdd3a..a3f6e8e 100644
--- a/tensorflow/compiler/xla/service/space_to_batch_converter.cc
+++ b/tensorflow/compiler/xla/service/space_to_batch_converter.cc
@@ -43,6 +43,7 @@
 #include "tensorflow/compiler/xla/service/shape_inference.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -207,6 +208,10 @@
   HloInstruction* DoesConvolutionFeedReduceWindowOrSelectAndScatter(
       HloInstruction* instr, int64 depth);
 
+  // Checks that space-to-batching has not made the new spatial dimension
+  // smaller than the window size.
+  bool IsSpaceToBatchedSpaceSizeSuitable(HloInstruction* instr);
+
  private:
   // Current HloComputation instance the ConvolutionVisitor is traversing.
   HloComputation* computation_;
@@ -803,7 +808,13 @@
     auto win_dims =
         consumer->window().dimensions(get_chosen_spatial_dim(consumer));
     const int64 rhs_dilation = win_dims.window_dilation();
+    const int64 lhs_dilation = win_dims.base_dilation();
 
+    // LHS dilations are supported by PropagateOnConv, and not by
+    // PropagateOnBackpropFilterConv.
+    if (lhs_dilation != 1) {
+      return false;
+    }
     // If the rhs_dilation is absent, we want both LHS and RHS to be space-to-
     // batched for propagating on backprop convolutions.
     if (!last_try || rhs_dilation == 1) {
@@ -928,6 +939,11 @@
         return false;
       }
     }
+
+    // Make sure the post space-to-batch dim size is at least as large as the window size.
+    if (consumer->opcode() == HloOpcode::kReduceWindow) {
+      return IsSpaceToBatchedSpaceSizeSuitable(consumer);
+    }
   }
 
   if (consumer->opcode() == HloOpcode::kSelectAndScatter) {
@@ -989,8 +1005,8 @@
                  "mismatch";
       return false;
     }
-    VLOG(1) << "Can propagate through select and scatter";
-    return true;
+
+    return IsSpaceToBatchedSpaceSizeSuitable(consumer);
   }
   return true;
 }
@@ -2588,6 +2604,28 @@
   return nullptr;
 }
 
+bool ConvolutionVisitor::IsSpaceToBatchedSpaceSizeSuitable(
+    HloInstruction* instr) {
+  CHECK(instr->opcode() == HloOpcode::kSelectAndScatter ||
+        instr->opcode() == HloOpcode::kReduceWindow);
+  auto old_producer = instr->mutable_operand(0);
+
+  auto dim_map_val_op = instr_to_dim_map_[old_producer];
+  const int64 old_space_dim = dim_map_val_op.second;
+  auto first_operand = old_to_new_instrs_[old_producer];
+  auto permute_dims_first_operand = instr_to_dim_permute_map_[first_operand];
+  const int64 new_space_dim =
+      DimLookUp(permute_dims_first_operand, old_space_dim);
+
+  const int64 window_size = instr->window().dimensions(old_space_dim).size();
+
+  if (first_operand->shape().dimensions(new_space_dim) < window_size) {
+    return false;
+  }
+
+  return true;
+}
+
 ConvolutionVisitor::ConvDetails ConvolutionVisitor::GetConvolutionDetails(
     HloInstruction* convolution, ConvolutionDimensionNumbers& dim_numbers) {
   auto activations = convolution->mutable_operand(0);
@@ -2669,6 +2707,11 @@
 
   VLOG(1) << "spatial size " << c.spatial_size;
 
+  // A very primitive cost model to thwart propagations on tiny shapes.
+  if (c.spatial_size < 2 * kNumSplits) {
+    return Status::OK();
+  }
+
   auto original_conv = convolution;
 
   const int64 output_spatial_dim = dim_numbers.output_spatial_dimensions(
@@ -2695,9 +2738,14 @@
     // Take into account the stride of the reduce window while choosing the
     // spatial_split_size. This will guarantee propagation through reduce
     // windows.
-    const int64 win_stride = reduce_window_or_select_and_scatter->window()
-                                 .dimensions(output_spatial_dim)
-                                 .stride();
+    const int64 win_stride =
+        std::max(reduce_window_or_select_and_scatter->window()
+                     .dimensions(output_spatial_dim)
+                     .stride(),
+                 static_cast<int64>(1));
+    CHECK_NE(win_stride, 0)
+        << "Bad op " << reduce_window_or_select_and_scatter->ToString();
+    CHECK_NE(c.stride, 0) << "Bad op " << convolution->ToString();
     while ((spatial_split_size / c.stride) % win_stride != 0) {
       spatial_split_size += c.stride;
     }
diff --git a/tensorflow/compiler/xla/service/space_to_batch_converter_test.cc b/tensorflow/compiler/xla/service/space_to_batch_converter_test.cc
index 8921d98..96cfe55 100644
--- a/tensorflow/compiler/xla/service/space_to_batch_converter_test.cc
+++ b/tensorflow/compiler/xla/service/space_to_batch_converter_test.cc
@@ -118,9 +118,9 @@
   
   HloModule module
 ENTRY computation {
-  %p2 = bf16[1,7,7,128]{3,0,2,1} parameter(0)
+  %p2 = bf16[1,28,28,128]{3,0,2,1} parameter(0)
   %p3 = bf16[1,1,512,128]{3,2,1,0} parameter(1)
-  ROOT %c = bf16[1,14,14,512]{3,0,2,1} convolution(%p2, %p3),
+  ROOT %c = bf16[1,56,56,512]{3,0,2,1} convolution(%p2, %p3),
     window={size=1x1 pad=0_1x0_1 lhs_dilate=2x2 rhs_reversal=1x1},
     dim_labels=b01f_01oi->b01f
 }