Improve BroadcastTo() to also support trivially broadcasting a dimension of size 1 to size n (in addition to the existing n -> n case). Then remove the now-unneeded special casing in tile_ops.cc.

In every case that previously triggered the special case, BroadcastTo() will now build a BroadcastInDim() whose broadcast_shape equals output_dims, so the trailing xla::Reshape is a no-op and is never emitted. The special case is therefore no longer needed.

PiperOrigin-RevId: 273015433
diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
index e1c764f..e8804ca 100644
--- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
@@ -79,29 +79,6 @@
       return;
     }
 
-    bool can_tile_with_implicit_broadcast = true;
-    for (int i = 0; i < input_dims; ++i) {
-      int64 multiple = multiples[i];
-      // If the multiple and input dimension are not 1, then tile cannot be
-      // implemented with a single hlo broadcast.
-      if (multiple != 1 && input_shape.dim_size(i) != 1) {
-        can_tile_with_implicit_broadcast = false;
-      }
-    }
-
-    if (can_tile_with_implicit_broadcast) {
-      // Create a constant Zero the size of the output shape to leverage binary
-      // operation broadcast semantics.
-      auto broadcasted_zero = xla::Broadcast(
-          XlaHelpers::Zero(ctx->builder(), ctx->input_type(0)), output_dims);
-      if (ctx->input_type(0) == DT_BOOL) {
-        ctx->SetOutput(0, xla::Or(broadcasted_zero, input));
-      } else {
-        ctx->SetOutput(0, xla::Add(broadcasted_zero, input));
-      }
-      return;
-    }
-
     auto result = BroadcastTo(ctx->Input("input"), output_dims);
     OP_REQUIRES_OK(ctx, result.status());
     ctx->SetOutput(0, result.ValueOrDie());
diff --git a/tensorflow/compiler/tf2xla/lib/broadcast.cc b/tensorflow/compiler/tf2xla/lib/broadcast.cc
index a0789f9..7251a2e 100644
--- a/tensorflow/compiler/tf2xla/lib/broadcast.cc
+++ b/tensorflow/compiler/tf2xla/lib/broadcast.cc
@@ -61,7 +61,7 @@
       }
 
       broadcast_dims.push_back(broadcast_shape.size());
-      if (*output_it == *input_it) {
+      if (*output_it == *input_it || *input_it == 1) {
         broadcast_shape.push_back(*output_it);
       } else if (*output_it != *input_it) {
         // Add dimensions [I, O/I], which we will later flatten to just