[XLA] Use dynamism inference to infer dynamic dimensions for reshape.
- Introduce dynamism inference function in xla builder, which tells if a value is dynamic or static.
- Use dynamism inference to infer whether an input to reshape's dimensions is dynamic.
- This removes the "-1" hack I made before in the bridge, makes the code cleaner. Plus it can support more complex cases dynamic reshape when the dimension comes from a series of transformations.

PiperOrigin-RevId: 325532056
Change-Id: Icc5bad39a857be77537e4736dd6863b833e2fe9d
diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc
index bf9a915..a85ba54 100644
--- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc
@@ -109,27 +109,33 @@
     VLOG(2) << "Reshape from " << input_shape.DebugString() << " to "
             << shape.DebugString() << ", unknown_index=" << unknown_index;
 
-    shape_input.clear();
-    // Run get input again, this time with dynamic dimension represented as
-    // "-1"
-    ctx->set_dynamic_dimension_is_minus_one(true);
-    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &shape_input));
-
     int dynamic_dimension = -1;
-
-    for (int d = 0; d < num_dims; ++d) {
-      const int32 size = shape_input[d];
-      if (size == -1) {
-        if (dynamic_dimension == -1) {
+    if (ctx->InputXlaShape(0)->is_dynamic()) {
+      std::vector<bool> dynamic_dims;
+      OP_REQUIRES_OK(ctx,
+                     ctx->ResolveInputDynamismIntoPredVector(1, &dynamic_dims));
+      for (int d = 0; d < num_dims; ++d) {
+        const bool dim_is_dynamic = dynamic_dims[d];
+        if (dim_is_dynamic) {
           dynamic_dimension = d;
-        } else {
-          if (unknown_index != d) {
-            dynamic_dimension = d;
-          }
         }
       }
-    }
 
+      // When reshaping from dynamic dimension, unkwown index is considered
+      // dynamic. E.g.,
+      //   [<=10]
+      //     |
+      // Reshape
+      //     |
+      //   [2, -1]
+      // The second dimension is dynamic.
+      if (dynamic_dimension == -1) {
+        dynamic_dimension = unknown_index;
+      }
+      VLOG(2) << "Reshape from " << ctx->InputXlaShape(0)->ToString() << " to "
+              << xla::VectorString(shape.dim_sizes())
+              << ", dynamic_dim=" << dynamic_dimension;
+    }
     // Pass unknown_index to Xla::Reshape as a hint for dynamic shape inference
     // in XLA to know which output dimension is dynamic.
     ctx->SetOutput(0, xla::ReshapeWithInferredDimension(
diff --git a/tensorflow/compiler/tf2xla/xla_expression.cc b/tensorflow/compiler/tf2xla/xla_expression.cc
index 34e108b..f0cc8d2 100644
--- a/tensorflow/compiler/tf2xla/xla_expression.cc
+++ b/tensorflow/compiler/tf2xla/xla_expression.cc
@@ -101,6 +101,48 @@
   });
 }
 
+xla::StatusOr<Tensor> XlaExpression::ResolveDynamism(
+    xla::Client* client) const {
+  switch (kind()) {
+    case Kind::kConstant: {
+      // Constant values are considered static.
+      Tensor constant_false(DT_BOOL, constant_value().shape());
+      auto flat = constant_false.flat<bool>();
+      for (int64 i = 0; i < flat.size(); ++i) flat(i) = false;
+      return constant_false;
+    }
+    case Kind::kXlaOp:
+      break;
+    case Kind::kTensorList:
+      TF_FALLTHROUGH_INTENDED;
+    case Kind::kResource:
+      TF_FALLTHROUGH_INTENDED;
+    case Kind::kInvalid:
+      return errors::InvalidArgument(
+          "ResolveDynamism called on unsupported XlaExpression: ",
+          HumanString());
+  }
+
+  if (!client)
+    return errors::InvalidArgument("client is required to resolve constant");
+
+  TF_ASSIGN_OR_RETURN(xla::XlaComputation constant_graph,
+                      handle().builder()->BuildDynamicInferenceGraph(handle()));
+
+  TF_ASSIGN_OR_RETURN(TensorShape shape, GetShape());
+
+  // The XLA layout is specified minor to major, and TensorFlow uses a major to
+  // minor order.
+  std::vector<int64> layout_indices(shape.dims());
+  std::iota(layout_indices.rbegin(), layout_indices.rend(), 0);
+  xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices);
+  TF_ASSIGN_OR_RETURN(xla::Literal literal,
+                      client->ComputeConstant(constant_graph, &layout));
+  Tensor tensor(DT_BOOL);
+  TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, DT_BOOL, &tensor));
+  return tensor;
+}
+
 xla::StatusOr<absl::optional<Tensor>> XlaExpression::ResolveConstant(
     xla::Client* client, bool dynamic_dimension_is_minus_one) const {
   switch (kind()) {
diff --git a/tensorflow/compiler/tf2xla/xla_expression.h b/tensorflow/compiler/tf2xla/xla_expression.h
index 3010964..3546368 100644
--- a/tensorflow/compiler/tf2xla/xla_expression.h
+++ b/tensorflow/compiler/tf2xla/xla_expression.h
@@ -99,6 +99,10 @@
   xla::StatusOr<absl::optional<Tensor>> ResolveConstant(
       xla::Client* client, bool dynamic_dimension_is_minus_one = false) const;
 
+  // ResolveDynamism computes where a value inside this op is dynamic or can be
+  // inferred at compile time.
+  xla::StatusOr<Tensor> ResolveDynamism(xla::Client* client) const;
+
   // Returns the shape of the tensor.
   // The shape of a resource is the shape of a resource handle (i.e., a scalar),
   // not the shape of the resource's value.
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index 735a6c7..0753754 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -243,6 +243,48 @@
   return LiteralToFloat64Scalar(literal, out);
 }
 
+static Status LiteralToPredVector(const xla::LiteralSlice& literal,
+                                  std::vector<bool>* out) {
+  if (literal.shape().rank() != 1) {
+    return errors::InvalidArgument("value is not 1D, rank: ",
+                                   literal.shape().rank());
+  }
+  int64 size = xla::ShapeUtil::ElementsIn(literal.shape());
+  if (literal.shape().element_type() != xla::PRED) {
+    return errors::InvalidArgument("value is not PRED");
+  }
+  for (int64 i = 0; i < size; ++i) {
+    out->push_back(literal.Get<bool>({i}));
+  }
+  return Status::OK();
+}
+
+Status XlaOpKernelContext::ResolveInputDynamismIntoPredVector(
+    int index, std::vector<bool>* out) {
+  xla::Literal literal;
+  XlaExpression e = InputExpression(index);
+  auto* client = compiler() ? compiler()->client() : nullptr;
+  xla::StatusOr<Tensor> dynamism_or_status = e.ResolveDynamism(client);
+  if (!dynamism_or_status.ok()) {
+    Status status = dynamism_or_status.status();
+    errors::AppendToMessage(&status, "while evaluating input dynamism", index,
+                            " of ", context_->op_kernel().type_string());
+    return status;
+  }
+  Tensor dynamism = dynamism_or_status.ValueOrDie();
+
+  Tensor temp(dynamism.dtype());
+  TensorShape tensor_shape({InputShape(index).num_elements()});
+  if (!temp.CopyFrom(dynamism, tensor_shape)) {
+    return errors::InvalidArgument(
+        context_->op_kernel().name(), " input ", index, " has shape ",
+        dynamism.shape().DebugString(), " which is not a R1 ", tensor_shape);
+  }
+
+  TF_ASSIGN_OR_RETURN(literal, HostTensorToLiteral(temp));
+  return LiteralToPredVector(literal, out);
+}
+
 // Converts an int32 or int64 1D literal to an int64 vector.
 static Status LiteralToInt64Vector(const xla::LiteralSlice& literal,
                                    std::vector<int64>* out) {
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index 3cf51e6..75c3e60 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -116,6 +116,9 @@
   // returns a one-element list.
   Status InputList(absl::string_view name, std::vector<xla::XlaOp>* handles,
                    std::vector<TensorShape>* shapes);
+  // Evaluates input and returns their dynamism vector in a vector of
+  // predicates.
+  Status ResolveInputDynamismIntoPredVector(int index, std::vector<bool>* out);
 
   // Helper methods for constant inputs.
 
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 484fb0a..8de8216 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -32,6 +32,7 @@
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/compiler/xla/comparison_util.h"
 #include "tensorflow/compiler/xla/execution_options_util.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -39,6 +40,7 @@
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/errors.h"
 
 namespace xla {
 
@@ -71,6 +73,52 @@
   entry->set_id(id);
   entry->set_name(GetFullName(base_name, separator, id));
 }
+
+ShapeProto ConvertShapeProtoToPred(const ShapeProto& shape_proto) {
+  return ShapeUtil::ChangeElementType(Shape(shape_proto), PRED).ToProto();
+}
+
+HloInstructionProto CreateConstantInstruction(int64 id, const Shape& shape,
+                                              bool pred) {
+  HloInstructionProto const_instr;
+  Literal literal = LiteralUtil::CreateR0(pred);
+  Literal literal_broadcast = literal.Broadcast(shape, {}).ValueOrDie();
+  *const_instr.mutable_shape() = shape.ToProto();
+  *const_instr.mutable_literal() = literal_broadcast.ToProto();
+  *const_instr.mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);
+  const_instr.set_id(id);
+  return const_instr;
+}
+
+// Converts a HloComputation into ReducerOr with predicate types.
+HloComputationProto CreateReduceOr(int64 reducer_id,
+                                   HloComputationProto* original_reducer) {
+  HloComputationProto reducer;
+  SetProtoIdAndName(&reducer, StrCat("reduce_or"), kNameSeparator, reducer_id);
+  std::vector<int64> operands_id;
+  for (auto& inst : original_reducer->instructions()) {
+    // Copy params.
+    if (StringToHloOpcode(inst.opcode()).ValueOrDie() ==
+        HloOpcode::kParameter) {
+      HloInstructionProto* new_param = reducer.add_instructions();
+      *new_param = inst;
+      *new_param->mutable_shape() = ConvertShapeProtoToPred(inst.shape());
+      operands_id.push_back(inst.id());
+    }
+    if (inst.id() == original_reducer->root_id()) {
+      HloInstructionProto* new_root = reducer.add_instructions();
+      *new_root = inst;
+      *new_root->mutable_shape() = ConvertShapeProtoToPred(inst.shape());
+      *new_root->mutable_opcode() = HloOpcodeString(HloOpcode::kOr);
+      new_root->clear_operand_ids();
+      for (int64 operand_id : operands_id) {
+        new_root->add_operand_ids(operand_id);
+      }
+      reducer.set_root_id(inst.id());
+    }
+  }
+  return reducer;
+}
 }  // namespace
 
 namespace internal {
@@ -2842,6 +2890,196 @@
   return is_constant;
 }
 
+StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
+  TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
+                      LookUpInstruction(root_op));
+
+  HloComputationProto entry;
+  SetProtoIdAndName(&entry, StrCat(name_, "_dynamic_inference"), kNameSeparator,
+                    GetNextId());
+  ProgramShapeProto* program_shape = entry.mutable_program_shape();
+  *program_shape->mutable_result() =
+      ShapeUtil::ChangeElementType(Shape(root->shape()), PRED).ToProto();
+
+  std::set<int64> seen;
+  struct WorkItem {
+    explicit WorkItem(int64 handle, bool need_rewrite)
+        : handle(handle), need_rewrite(need_rewrite) {}
+    int64 handle;
+    // If need_rewrite is true, the instruction will be copied and rewrite into
+    // a pred instruction indicating if each value is dynamic. If need_rewrite
+    // is false, simply copy the instruction to the output graph.
+    // E.g.,
+    // For select(P, A, B), we need to rewrite A and B into predicates, but
+    // don't need to rewrite P.
+    bool need_rewrite;
+  };
+  std::queue<WorkItem> worklist;
+  worklist.push(WorkItem(root->id(), true));
+  entry.set_root_id(root->id());
+  std::vector<HloComputationProto> called_computatons;
+  // Rewritre instruction with id "from" into the new graph.
+  // Returns more work items that need to finish.
+  auto rewrite_instruction =
+      [&](int64 from, bool need_rewrite) -> StatusOr<std::vector<WorkItem>> {
+    // Rewrite the instruction with following rules:
+    // - Unary ops: Convert into bitcast (identity) with type Pred.
+    // - Binary ops: Convert into binary or.
+    // - Select: Convert into binary or with its two data operands.
+    // - Concat / Tuple/ GTE / Bitcast: Copy.
+    // - Param: Convert to constant True.
+    // - GetDimensionSize: Convert to constant True if dimension is dynamic,
+    // contant False if dimension is static.
+    // - Reduce: Convert to reduce or.
+    // - Constant: Convert to constant False.
+    // - Other ops: Not supported.
+    // Create the instruction for the new handle.
+    TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto,
+                        LookUpInstructionByHandle(from));
+
+    TF_ASSIGN_OR_RETURN(HloOpcode opcode,
+                        StringToHloOpcode(instr_proto->opcode()));
+    std::vector<WorkItem> operands_todo;
+    auto* new_instr = entry.add_instructions();
+    *new_instr = *instr_proto;
+    for (auto operand_id : new_instr->operand_ids()) {
+      operands_todo.emplace_back(operand_id, need_rewrite);
+    }
+
+    if (!need_rewrite) {
+      *new_instr->mutable_name() =
+          GetFullName(instr_proto->opcode(), kNameSeparator, instr_proto->id());
+      return operands_todo;
+    }
+    *new_instr->mutable_shape() = ConvertShapeProtoToPred(instr_proto->shape());
+    Shape new_shape(new_instr->shape());
+    switch (opcode) {
+      case HloOpcode::kAbs:
+      case HloOpcode::kRoundNearestAfz:
+      case HloOpcode::kBitcast:
+      case HloOpcode::kCeil:
+      case HloOpcode::kCollectivePermuteDone:
+      case HloOpcode::kCos:
+      case HloOpcode::kClz:
+      case HloOpcode::kExp:
+      case HloOpcode::kExpm1:
+      case HloOpcode::kFloor:
+      case HloOpcode::kImag:
+      case HloOpcode::kIsFinite:
+      case HloOpcode::kLog:
+      case HloOpcode::kLog1p:
+      case HloOpcode::kNot:
+      case HloOpcode::kNegate:
+      case HloOpcode::kPopulationCount:
+      case HloOpcode::kReal:
+      case HloOpcode::kRsqrt:
+      case HloOpcode::kLogistic:
+      case HloOpcode::kSign:
+      case HloOpcode::kSin:
+      case HloOpcode::kConvert:
+      case HloOpcode::kSqrt:
+      case HloOpcode::kCbrt:
+      case HloOpcode::kTanh:
+        CHECK_EQ(instr_proto->operand_ids_size(), 1);
+        *new_instr->mutable_opcode() = HloOpcodeString(HloOpcode::kBitcast);
+        break;
+      case HloOpcode::kAdd:
+      case HloOpcode::kAtan2:
+      case HloOpcode::kDivide:
+      case HloOpcode::kComplex:
+      case HloOpcode::kMaximum:
+      case HloOpcode::kMinimum:
+      case HloOpcode::kMultiply:
+      case HloOpcode::kPower:
+      case HloOpcode::kRemainder:
+      case HloOpcode::kSubtract:
+      case HloOpcode::kCompare:
+      case HloOpcode::kAnd:
+      case HloOpcode::kOr:
+      case HloOpcode::kXor:
+      case HloOpcode::kShiftLeft:
+      case HloOpcode::kShiftRightArithmetic:
+      case HloOpcode::kShiftRightLogical:
+        CHECK_EQ(instr_proto->operand_ids_size(), 2);
+        *new_instr->mutable_opcode() = HloOpcodeString(HloOpcode::kOr);
+        break;
+      case HloOpcode::kSelect:
+        operands_todo[0].need_rewrite = false;
+        break;
+      case HloOpcode::kGather:
+        operands_todo[1].need_rewrite = false;
+        break;
+      case HloOpcode::kReduce: {
+        int64 reducer_id = new_instr->called_computation_ids(0);
+        called_computatons.push_back(
+            CreateReduceOr(reducer_id, &embedded_[reducer_id]));
+        break;
+      }
+      case HloOpcode::kTuple:
+      case HloOpcode::kTranspose:
+      case HloOpcode::kGetTupleElement:
+      case HloOpcode::kSlice:
+      case HloOpcode::kBroadcast:
+      case HloOpcode::kConcatenate:
+      case HloOpcode::kReshape:
+        break;
+      case HloOpcode::kGetDimensionSize: {
+        int64 dimension = instr_proto->dimensions(0);
+        int64 operand_handle = instr_proto->operand_ids(0);
+        TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
+                            LookUpInstructionByHandle(operand_handle));
+
+        *new_instr = CreateConstantInstruction(
+            from, new_shape,
+            operand_proto->shape().is_dynamic_dimension(dimension));
+        operands_todo.clear();
+        break;
+      }
+      case HloOpcode::kConstant:
+        *new_instr = CreateConstantInstruction(from, new_shape, false);
+        break;
+      case HloOpcode::kParameter:
+        *new_instr = CreateConstantInstruction(from, new_shape, true);
+        break;
+      default:
+        return InvalidArgument("Dynamic inferencing %s is not supported",
+                               instr_proto->DebugString());
+    }
+    *new_instr->mutable_name() =
+        GetFullName(instr_proto->opcode(), kNameSeparator, instr_proto->id());
+    return operands_todo;
+  };
+
+  while (!worklist.empty()) {
+    WorkItem item = worklist.front();
+    worklist.pop();
+    if (!seen.insert(item.handle).second) {
+      continue;
+    }
+    TF_ASSIGN_OR_RETURN(auto todos,
+                        rewrite_instruction(item.handle, item.need_rewrite));
+    for (WorkItem& todo : todos) {
+      worklist.push(todo);
+    }
+  }
+  absl::c_sort(*entry.mutable_instructions(),
+               [](const HloInstructionProto& p1,
+                  const HloInstructionProto& p2) { return p1.id() < p2.id(); });
+  XlaComputation computation(entry.id());
+  HloModuleProto* module = computation.mutable_proto();
+  module->set_name(entry.name());
+  module->set_id(entry.id());
+  module->set_entry_computation_name(entry.name());
+  module->set_entry_computation_id(entry.id());
+  *module->mutable_host_program_shape() = *program_shape;
+  for (auto& called_comp : called_computatons) {
+    *module->add_computations() = called_comp;
+  }
+  *module->add_computations() = std::move(entry);
+  XLA_VLOG_LINES(3, module->DebugString());
+  return std::move(computation);
+}
+
 StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
     XlaOp root_op, bool dynamic_dimension_is_minus_one) {
   TF_ASSIGN_OR_RETURN(bool is_constant, IsConstant(root_op));
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index aa5074d..6753b6d 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -278,6 +278,31 @@
   StatusOr<XlaComputation> BuildConstantSubGraph(
       XlaOp root_op, bool dynamic_dimension_is_uint_max = false);
 
+  // Similar to BuildConstantSubGraph, but with root element type changed to
+  // boolean. A true value in the root indicates that the value is dynamic while
+  // false value indicates that the value is a constant. This will copy the
+  // needed ops/computations to the subgraph.
+  //
+  // E.g.,
+  // Compuptation {
+  //   a = 3
+  //   b = param(0)
+  //   ROOT Tuple(a + b, a + 1, b + 1)
+  // }
+  // Calling BuildDynamicInferenceGraph on root will produce the following
+  // graph:
+  //
+  // Compuptation {
+  //   a = False
+  //   b = True
+  //   ROOT Tuple(a | b, a, b)
+  // }
+  //
+  // The result, which is (True, False, True) after evaluation, can be
+  // interpreted as "First element is dynamic; Second element is static; Third
+  // element is dynamic".
+  StatusOr<XlaComputation> BuildDynamicInferenceGraph(XlaOp root_op);
+
   // Returns the first error that was encountered while building the
   // computation. When an error is encountered, by default we return a vacuous
   // XlaOp and inform the user of the error that occurred while
diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
index 2f24568..36429d3 100644
--- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
+++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
@@ -805,7 +805,8 @@
         }
 
         if (input_dim_size > output_dim_size) {
-          TF_RET_CHECK(input_dim_size % output_dim_size == 0);
+          TF_RET_CHECK(input_dim_size % output_dim_size == 0)
+              << reshape->ToString();
           const int64 divisor = input_dim_size / output_dim_size;
           HloInstruction* divisor_hlo =
               hlo->parent()->AddInstruction(HloInstruction::CreateConstant(
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 02fcaaf..0833919 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -783,9 +783,18 @@
 
 /* static */ Shape ShapeUtil::ChangeElementType(const Shape& original,
                                                 PrimitiveType type) {
-  Shape new_shape = original;
-  new_shape.set_element_type(type);
-  return new_shape;
+  if (original.IsTuple()) {
+    std::vector<Shape> new_operands;
+    new_operands.reserve(original.tuple_shapes_size());
+    for (const Shape& operand : original.tuple_shapes()) {
+      new_operands.push_back(ChangeElementType(operand, type));
+    }
+    return MakeTupleShape(new_operands);
+  } else {
+    Shape new_shape = original;
+    new_shape.set_element_type(type);
+    return new_shape;
+  }
 }
 
 /* static */ bool ShapeUtil::IndexIsValid(const Shape& shape,
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 927f9d1..17444c0 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -2089,6 +2089,31 @@
 )
 
 xla_test(
+    name = "dynamism_inference_test",
+    srcs = ["dynamism_inference_test.cc"],
+    deps = [
+        ":test_macros_header",
+        "//tensorflow/compiler/xla:literal",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla:xla_data_proto_cc",
+        "//tensorflow/compiler/xla/client:client_library",
+        "//tensorflow/compiler/xla/client:global_data",
+        "//tensorflow/compiler/xla/client:xla_builder",
+        "//tensorflow/compiler/xla/client:xla_computation",
+        "//tensorflow/compiler/xla/client/lib:prng",
+        "//tensorflow/compiler/xla/tests:literal_test_util",
+        "//tensorflow/compiler/xla/tests:test_utils",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test",
+        "@com_google_absl//absl/strings",
+    ],
+)
+
+xla_test(
     name = "compute_constant_test",
     srcs = ["compute_constant_test.cc"],
     deps = [
diff --git a/tensorflow/compiler/xla/tests/dynamism_inference_test.cc b/tensorflow/compiler/xla/tests/dynamism_inference_test.cc
new file mode 100644
index 0000000..ba4092d
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/dynamism_inference_test.cc
@@ -0,0 +1,215 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/strings/match.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/lib/prng.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+// An enumerator for the client types that we want to iterate over in
+// the various tests.
+enum class ClientType { kLocal, kCompileOnly };
+ClientType client_types[] = {ClientType::kLocal, ClientType::kCompileOnly};
+
+class DynamismInferenceTest : public ::testing::Test {
+ public:
+  explicit DynamismInferenceTest(se::Platform* platform = nullptr)
+      : platform_(platform) {}
+
+  string TestName() const {
+    return ::testing::UnitTest::GetInstance()->current_test_info()->name();
+  }
+
+  Client* ClientOrDie(se::Platform* platform, ClientType client_type) {
+    if (client_type == ClientType::kLocal) {
+      StatusOr<Client*> result =
+          ClientLibrary::GetOrCreateLocalClient(platform);
+      TF_CHECK_OK(result.status())
+          << "could not create LocalClient for testing";
+      return result.ValueOrDie();
+    } else if (client_type == ClientType::kCompileOnly) {
+      StatusOr<Client*> result =
+          ClientLibrary::GetOrCreateCompileOnlyClient(platform);
+      TF_CHECK_OK(result.status())
+          << "could not create CompileOnlyClient for testing";
+      return result.ValueOrDie();
+    }
+    LOG(FATAL) << "invalid client_type value";
+  }
+
+  StatusOr<Literal> ComputeDynamismLiteral(Client* client, XlaOp operand,
+                                           XlaBuilder* builder,
+                                           Layout* output_layout = nullptr) {
+    TF_ASSIGN_OR_RETURN(auto subgraph,
+                        builder->BuildDynamicInferenceGraph(operand));
+    TF_ASSIGN_OR_RETURN(auto computed,
+                        client->ComputeConstant(subgraph, output_layout));
+    return std::move(computed);
+  }
+
+  StatusOr<bool> ComputeDynamismScalar(Client* client, XlaOp operand,
+                                       XlaBuilder* builder,
+                                       ShapeIndex index = {}) {
+    TF_ASSIGN_OR_RETURN(auto literal, ComputeDynamismLiteral(client, operand,
+                                                             builder, nullptr));
+    return literal.Get<bool>({}, index);
+  }
+
+  se::Platform* platform_;
+};
+
+TEST_F(DynamismInferenceTest, ScalarInt32Literal) {
+  for (ClientType client_type : client_types) {
+    Client* client = ClientOrDie(platform_, client_type);
+    XlaBuilder b(TestName());
+    auto computation = ConstantR0<int32>(&b, 42);
+
+    auto value = ComputeDynamismScalar(client, computation, &b);
+    ASSERT_TRUE(value.ok()) << value.status();
+    // A constant is not dynamic.
+    EXPECT_EQ(value.ValueOrDie(), false);
+  }
+}
+
+TEST_F(DynamismInferenceTest, TupleGteKeepsDynamism) {
+  for (ClientType client_type : client_types) {
+    Client* client = ClientOrDie(platform_, client_type);
+    XlaBuilder b(TestName());
+    auto c = ConstantR0<int32>(&b, 42);
+    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");
+
+    auto tuple = Tuple(&b, {c, p});
+    auto gte0 = GetTupleElement(tuple, 0);
+    auto gte1 = GetTupleElement(tuple, 1);
+    auto tuple_2 = Tuple(&b, {gte0, gte1});
+    EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(),
+              false);
+    EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(),
+              true);
+  }
+}
+
+TEST_F(DynamismInferenceTest, ConcatSliceReshapeKeepsDynamism) {
+  for (ClientType client_type : client_types) {
+    Client* client = ClientOrDie(platform_, client_type);
+    XlaBuilder b(TestName());
+    auto c = ConstantR0<int32>(&b, 42);
+    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");
+
+    auto concat = ConcatScalars(&b, {c, p});
+    auto slice0 = SliceInDim(concat, 0, 1, 1, 0);
+    auto reshape0 = Reshape(slice0, {});
+    auto slice1 = SliceInDim(concat, 1, 2, 1, 0);
+    auto reshape1 = Reshape(slice1, {});
+    auto tuple_2 = Tuple(&b, {reshape0, reshape1});
+    EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(),
+              false);
+    EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(),
+              true);
+  }
+}
+
+TEST_F(DynamismInferenceTest, ParameterIsDynamic) {
+  for (ClientType client_type : client_types) {
+    Client* client = ClientOrDie(platform_, client_type);
+    XlaBuilder b(TestName());
+    auto computation = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");
+
+    auto value = ComputeDynamismScalar(client, computation, &b);
+    ASSERT_TRUE(value.ok()) << value.status();
+    // A parameter is considered dynamic.
+    EXPECT_EQ(value.ValueOrDie(), true);
+  }
+}
+
+TEST_F(DynamismInferenceTest, UnaryOpKeepsDynamism) {
+  for (ClientType client_type : client_types) {
+    Client* client = ClientOrDie(platform_, client_type);
+    XlaBuilder b(TestName());
+    auto c = ConstantR0<int32>(&b, 42);
+    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");
+
+    auto neg0 = Neg(c);
+    auto neg1 = Neg(p);
+    auto tuple_2 = Tuple(&b, {neg0, neg1});
+    EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(),
+              false);
+    EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(),
+              true);
+  }
+}
+
+TEST_F(DynamismInferenceTest, BinaryOpsOrsDynamism) {
+  for (ClientType client_type : client_types) {
+    Client* client = ClientOrDie(platform_, client_type);
+    XlaBuilder b(TestName());
+    auto c = ConstantR0<int32>(&b, 42);
+    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");
+
+    // Static value + static value = static
+    auto add1 = Add(c, c);
+    // Dynamic value + dynamic value = dynamic
+    auto add2 = Add(p, c);
+    auto tuple_2 = Tuple(&b, {add1, add2});
+    EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(),
+              false);
+    EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(),
+              true);
+  }
+}
+
+TEST_F(DynamismInferenceTest, GetDimensionSize) {
+  for (ClientType client_type : client_types) {
+    Client* client = ClientOrDie(platform_, client_type);
+    XlaBuilder b(TestName());
+    // param = Param([<=2, 3])
+    // get_dimension_size(param, 0) is dynamic
+    // get_dimension_size(param, 1) is static
+    auto p =
+        Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}), "0");
+
+    auto gds0 = GetDimensionSize(p, 0);
+    auto gds1 = GetDimensionSize(p, 1);
+    auto tuple_2 = Tuple(&b, {gds0, gds1});
+    EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(),
+              true);
+    EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(),
+              false);
+  }
+}
+
+}  // namespace
+}  // namespace xla