[Resubmit] Value inference on upper and lower bounds.

- This cl allows xla to infer both the upper and lower bounds of common ops.
- The client can trigger the inference through xla op kernel API.
- Useful for inferring bounds through subtract and divide, e.g.:
T = # some_dynamic_tensor
tf.range(x - tf.size(T)) # Bound of tf.range is x, instead of x-tf.size(T).
PiperOrigin-RevId: 367139035
Change-Id: I0fe76aa631af287a532c1b99a5e3d4452fbf711c
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index edc89aa..4f93d3e 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -362,6 +362,7 @@
         "//tensorflow/compiler/jit:shape_inference",
         "//tensorflow/compiler/mlir:array_container_utils",
         "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
+        "//tensorflow/compiler/xla/client:value_inference",
         "//tensorflow/compiler/xla:protobuf_util",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc
index dfbad70..04b97ce 100644
--- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc
@@ -84,8 +84,10 @@
                 errors::InvalidArgument("delta must be a scalar, not shape ",
                                         delta_in_shape.DebugString()));
     xla::Literal start, limit, delta;
-    OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &start));
-    OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &limit));
+    OP_REQUIRES_OK(ctx, ctx->ConstantInput(
+                            0, &start, xla::ValueInferenceMode::kLowerBound));
+    OP_REQUIRES_OK(ctx, ctx->ConstantInput(
+                            1, &limit, xla::ValueInferenceMode::kUpperBound));
     OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &delta));
 
     DataType type = input_type(0);
@@ -108,23 +110,36 @@
                                          DataTypeString(type));
     }
     OP_REQUIRES_OK(ctx, output.status());
+    bool start_is_dynamic = false;
+    OP_REQUIRES_OK(ctx,
+                   ctx->ResolveInputDynamismIntoPred(0, &start_is_dynamic));
+    bool limit_is_dynamic = false;
+    OP_REQUIRES_OK(ctx,
+                   ctx->ResolveInputDynamismIntoPred(1, &limit_is_dynamic));
 
-    if (type == DT_INT32 || type == DT_INT64) {
-      bool limit_is_dynamic = false;
-      OP_REQUIRES_OK(ctx,
-                     ctx->ResolveInputDynamismIntoPred(1, &limit_is_dynamic));
-      if (type == DT_INT32) {
-        if (limit_is_dynamic) {
-          output = xla::SetDimensionSize(output.ValueOrDie(), ctx->Input(1), 0);
-        }
+    if (start_is_dynamic || limit_is_dynamic) {
+      // start/limit were resolved to their bounds above, so record the actual
+      // element count as the dynamic size of dimension 0.
+      xla::XlaOp delta = ctx->Input(2);
+      xla::XlaOp limit = ctx->Input(1);
+      xla::XlaOp start = ctx->Input(0);
+      if (type == DT_INT32 || type == DT_INT64) {
+        // Integer ceil-division: (|limit - start| + |delta| - 1) / |delta|.
+        auto dynamic_size =
+            ((xla::Abs(limit - start) + xla::Abs(delta) -
+              xla::One(ctx->builder(), ctx->input_xla_type(0))) /
+             xla::Abs(delta));
+        // SetDimensionSize takes an S32 scalar; the inputs may be S64.
+        dynamic_size = xla::ConvertElementType(dynamic_size, xla::S32);
+        output = xla::SetDimensionSize(output.ValueOrDie(), dynamic_size, 0);
       } else {
-        if (limit_is_dynamic) {
-          output = xla::SetDimensionSize(
-              output.ValueOrDie(),
-              xla::ConvertElementType(ctx->Input(1), xla::S32), 0);
-        }
+        auto dynamic_size = (xla::Ceil(xla::Abs((limit - start) / delta)));
+        // SetDimensionSize takes an S32 scalar; convert the float count.
+        dynamic_size = xla::ConvertElementType(dynamic_size, xla::S32);
+        output = xla::SetDimensionSize(output.ValueOrDie(), dynamic_size, 0);
       }
     }
+
     ctx->SetOutput(0, output.ValueOrDie());
   }
 };
diff --git a/tensorflow/compiler/tf2xla/xla_expression.cc b/tensorflow/compiler/tf2xla/xla_expression.cc
index 099d546..87fdc63 100644
--- a/tensorflow/compiler/tf2xla/xla_expression.cc
+++ b/tensorflow/compiler/tf2xla/xla_expression.cc
@@ -152,7 +152,8 @@
 }
 
 xla::StatusOr<absl::optional<Tensor>> XlaExpression::ResolveConstant(
-    xla::Client* client, bool dynamic_dimension_is_minus_one) const {
+    xla::Client* client, bool dynamic_dimension_is_minus_one,
+    xla::ValueInferenceMode mode) const {
   switch (kind()) {
     case Kind::kConstant:
     case Kind::kResource:
@@ -165,6 +166,22 @@
       return errors::InvalidArgument(
           "ResolveConstant called on XlaExpression: ", HumanString());
   }
+  TF_ASSIGN_OR_RETURN(TensorShape shape, GetShape());
+  if (mode == xla::ValueInferenceMode::kLowerBound ||
+      mode == xla::ValueInferenceMode::kUpperBound) {
+    std::vector<int64> layout_indices(shape.dims());
+    std::iota(layout_indices.rbegin(), layout_indices.rend(), 0);
+    xla::ValueInference value_inference(handle().builder());
+    TF_ASSIGN_OR_RETURN(xla::OptionalLiteral literal,
+                        value_inference.AnalyzeConstant(handle(), mode));
+    if (!literal.GetValue().has_value()) {
+      return {absl::nullopt};
+    }
+    Tensor tensor;
+    TF_RETURN_IF_ERROR(
+        LiteralToHostTensor(literal.GetValue().value(), dtype(), &tensor));
+    return {tensor};
+  }
 
   TF_ASSIGN_OR_RETURN(bool is_constant,
                       handle().builder()->IsConstant(handle()));
@@ -179,8 +196,6 @@
                       handle().builder()->BuildConstantSubGraph(
                           handle(), dynamic_dimension_is_minus_one));
 
-  TF_ASSIGN_OR_RETURN(TensorShape shape, GetShape());
-
   // The XLA layout is specified minor to major, and TensorFlow uses a major to
   // minor order.
   std::vector<int64> layout_indices(shape.dims());
diff --git a/tensorflow/compiler/tf2xla/xla_expression.h b/tensorflow/compiler/tf2xla/xla_expression.h
index 408afef..69f367d 100644
--- a/tensorflow/compiler/tf2xla/xla_expression.h
+++ b/tensorflow/compiler/tf2xla/xla_expression.h
@@ -19,6 +19,7 @@
 #include "absl/types/optional.h"
 #include "tensorflow/compiler/tf2xla/xla_resource.h"
 #include "tensorflow/compiler/xla/client/client.h"
+#include "tensorflow/compiler/xla/client/value_inference.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -106,16 +107,17 @@
   // Returns a human-readable summary of the expression.
   string HumanString() const;
 
   // Returns the value of a kConstant or kXlaOp as an xla::XlaOp. Returns
   // an erroneous XlaOp if the expression is not a constant or an expression.
   xla::XlaOp AsXlaOp(xla::XlaBuilder* builder) const;
 
   // If a kXlaOp or kConstant expression can be resolved to a compile-time
   // constant, returns the value as a host-memory Tensor. Returns an empty
   // optional if it cannot be resolved. Returns an error if passed a resource
   // expression.
   xla::StatusOr<absl::optional<Tensor>> ResolveConstant(
-      xla::Client* client, bool dynamic_dimension_is_minus_one = false) const;
+      xla::Client* client, bool dynamic_dimension_is_minus_one = false,
+      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue) const;
 
   // ResolveDynamism computes where a value inside this op is dynamic or can be
   // inferred at compile time.
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index f2eb038..7b02df2 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -128,9 +128,11 @@
 }
 
 Status XlaOpKernelContext::ConstantInput(int index,
-                                         xla::Literal* constant_literal) {
-  return ConstantInputReshaped(
-      index, context_->input(index).shape().dim_sizes(), constant_literal);
+                                         xla::Literal* constant_literal,
+                                         xla::ValueInferenceMode mode) {
+  return ConstantInputReshaped(index,
+                               context_->input(index).shape().dim_sizes(),
+                               constant_literal, mode);
 }
 
 static xla::StatusOr<int> InputIndex(XlaOpKernelContext* context,
@@ -147,18 +149,19 @@
 }
 
 Status XlaOpKernelContext::ConstantInput(absl::string_view name,
-                                         xla::Literal* constant_literal) {
+                                         xla::Literal* constant_literal,
+                                         xla::ValueInferenceMode mode) {
   TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
-  return ConstantInput(index, constant_literal);
+  return ConstantInput(index, constant_literal, mode);
 }
 
 Status XlaOpKernelContext::ConstantInputReshaped(
-    int index, absl::Span<const int64> new_dims,
-    xla::Literal* constant_literal) {
+    int index, absl::Span<const int64> new_dims, xla::Literal* constant_literal,
+    xla::ValueInferenceMode mode) {
   XlaExpression e = InputExpression(index);
   auto* client = compiler() ? compiler()->client() : nullptr;
   xla::StatusOr<absl::optional<Tensor>> constant_or_status =
-      e.ResolveConstant(client, dynamic_dimension_is_minus_one_);
+      e.ResolveConstant(client, dynamic_dimension_is_minus_one_, mode);
   if (!constant_or_status.ok()) {
     Status status = constant_or_status.status();
     errors::AppendToMessage(&status, "while evaluating input ", index, " of ",
@@ -225,21 +228,23 @@
   return Status::OK();
 }
 
-Status XlaOpKernelContext::ConstantInputAsIntScalar(int index, int64* out) {
+Status XlaOpKernelContext::ConstantInputAsIntScalar(
+    int index, int64* out, xla::ValueInferenceMode mode) {
   xla::Literal literal;
-  TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
+  TF_RETURN_IF_ERROR(ConstantInput(index, &literal, mode));
   return LiteralToInt64Scalar(literal, out);
 }
 
-Status XlaOpKernelContext::ConstantInputAsIntScalar(absl::string_view name,
-                                                    int64* out) {
+Status XlaOpKernelContext::ConstantInputAsIntScalar(
+    absl::string_view name, int64* out, xla::ValueInferenceMode mode) {
   TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
-  return ConstantInputAsIntScalar(index, out);
+  return ConstantInputAsIntScalar(index, out, mode);
 }
 
-Status XlaOpKernelContext::ConstantInputAsFloatScalar(int index, double* out) {
+Status XlaOpKernelContext::ConstantInputAsFloatScalar(
+    int index, double* out, xla::ValueInferenceMode mode) {
   xla::Literal literal;
-  TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
+  TF_RETURN_IF_ERROR(ConstantInput(index, &literal, mode));
   return LiteralToFloat64Scalar(literal, out);
 }
 
@@ -341,40 +346,42 @@
   return Status::OK();
 }
 
-Status XlaOpKernelContext::ConstantInputAsIntVector(int index,
-                                                    std::vector<int64>* out) {
+Status XlaOpKernelContext::ConstantInputAsIntVector(
+    int index, std::vector<int64>* out, xla::ValueInferenceMode mode) {
   xla::Literal literal;
-  TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
+  TF_RETURN_IF_ERROR(ConstantInput(index, &literal, mode));
   return LiteralToInt64Vector(literal, out);
 }
 
-Status XlaOpKernelContext::ConstantInputAsIntVector(absl::string_view name,
-                                                    std::vector<int64>* out) {
+Status XlaOpKernelContext::ConstantInputAsIntVector(
+    absl::string_view name, std::vector<int64>* out,
+    xla::ValueInferenceMode mode) {
   TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
-  return ConstantInputAsIntVector(index, out);
+  return ConstantInputAsIntVector(index, out, mode);
 }
 
 Status XlaOpKernelContext::ConstantInputReshapedToIntVector(
-    int index, std::vector<int64>* out) {
+    int index, std::vector<int64>* out, xla::ValueInferenceMode mode) {
   xla::Literal literal;
   TF_RETURN_IF_ERROR(ConstantInputReshaped(
-      index, {InputShape(index).num_elements()}, &literal));
+      index, {InputShape(index).num_elements()}, &literal, mode));
   return LiteralToInt64Vector(literal, out);
 }
 
 Status XlaOpKernelContext::ConstantInputReshapedToIntVector(
-    absl::string_view name, std::vector<int64>* out) {
+    absl::string_view name, std::vector<int64>* out,
+    xla::ValueInferenceMode mode) {
   TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
   xla::Literal literal;
   TF_RETURN_IF_ERROR(ConstantInputReshaped(
-      index, {InputShape(index).num_elements()}, &literal));
+      index, {InputShape(index).num_elements()}, &literal, mode));
   return LiteralToInt64Vector(literal, out);
 }
 
-Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index,
-                                                       xla::Literal* out) {
+Status XlaOpKernelContext::ConstantInputAsInt64Literal(
+    int index, xla::Literal* out, xla::ValueInferenceMode mode) {
   xla::Literal literal;
-  TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
+  TF_RETURN_IF_ERROR(ConstantInput(index, &literal, mode));
   switch (literal.shape().element_type()) {
     case xla::S32: {
       *out = xla::Literal(
@@ -396,17 +403,18 @@
   }
 }
 
-Status XlaOpKernelContext::ConstantInputAsInt64Literal(absl::string_view name,
-                                                       xla::Literal* out) {
+Status XlaOpKernelContext::ConstantInputAsInt64Literal(
+    absl::string_view name, xla::Literal* out, xla::ValueInferenceMode mode) {
   TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
-  return ConstantInputAsInt64Literal(index, out);
+  return ConstantInputAsInt64Literal(index, out, mode);
 }
 
 // TODO(phawkins): validate that the dimensions form a valid shape, fail
 // gracefully if they do not.
-Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) {
+Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape,
+                                                xla::ValueInferenceMode mode) {
   xla::Literal literal;
-  TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
+  TF_RETURN_IF_ERROR(ConstantInput(index, &literal, mode));
   std::vector<int64> dims;
   TF_RETURN_IF_ERROR(LiteralToInt64Vector(literal, &dims));
   *shape = TensorShape(dims);
@@ -449,13 +457,14 @@
   return Status::OK();
 }
 
-Status XlaOpKernelContext::ConstantInputList(
-    absl::string_view name, std::vector<xla::Literal>* outputs) {
+Status XlaOpKernelContext::ConstantInputList(absl::string_view name,
+                                             std::vector<xla::Literal>* outputs,
+                                             xla::ValueInferenceMode mode) {
   int start, stop;
   TF_RETURN_IF_ERROR(op_kernel().InputRange(name, &start, &stop));
   outputs->resize(stop - start);
   for (int i = start; i < stop; ++i) {
-    TF_RETURN_IF_ERROR(ConstantInput(i, &(*outputs)[i]));
+    TF_RETURN_IF_ERROR(ConstantInput(i, &(*outputs)[i], mode));
   }
   return Status::OK();
 }
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index 48c66b9..4cfe730 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -20,6 +20,7 @@
 #include "tensorflow/compiler/tf2xla/xla_context.h"
 #include "tensorflow/compiler/tf2xla/xla_expression.h"
 #include "tensorflow/compiler/tf2xla/xla_resource.h"
+#include "tensorflow/compiler/xla/client/value_inference.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -130,34 +131,58 @@
 
   // Evaluates input `index` and stores it in `*constant_literal`. If the
   // expression cannot be evaluated, e.g., because it depends on unbound
-  // parameters, returns a non-OK status.
-  Status ConstantInput(int index, xla::Literal* constant_literal);
-  Status ConstantInput(absl::string_view name, xla::Literal* constant_literal);
+  // parameters, returns a non-OK status. This function can also be used to
+  // infer constant input upper or lower bounds, by changing the `mode`
+  // parameter.
+  Status ConstantInput(
+      int index, xla::Literal* constant_literal,
+      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
+  Status ConstantInput(
+      absl::string_view name, xla::Literal* constant_literal,
+      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
 
   // Converts a constant scalar int32 or int64 tensor into an int64.
-  Status ConstantInputAsIntScalar(int index, int64* out);
-  Status ConstantInputAsIntScalar(absl::string_view name, int64* out);
+  Status ConstantInputAsIntScalar(
+      int index, int64* out,
+      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
+  Status ConstantInputAsIntScalar(
+      absl::string_view name, int64* out,
+      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
 
   // Converts a constant scalar float32 or float64 tensor into a float64.
-  Status ConstantInputAsFloatScalar(int index, double* out);
+  Status ConstantInputAsFloatScalar(
+      int index, double* out,
+      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
 
   // Converts a constant 1D int32 or int64 tensor into a vector of int64s.
-  Status ConstantInputAsIntVector(int index, std::vector<int64>* out);
-  Status ConstantInputAsIntVector(absl::string_view name,
-                                  std::vector<int64>* out);
+  Status ConstantInputAsIntVector(
+      int index, std::vector<int64>* out,
+      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
+  Status ConstantInputAsIntVector(
+      absl::string_view name, std::vector<int64>* out,
+      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
 
   // Reshapes and converts a constant int32 or int64 tensor into a vector of
   // int64s.
-  Status ConstantInputReshapedToIntVector(int index, std::vector<int64>* out);
-  Status ConstantInputReshapedToIntVector(absl::string_view name,
-                                          std::vector<int64>* out);
+  Status ConstantInputReshapedToIntVector(
+      int index, std::vector<int64>* out,
+      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
+  Status ConstantInputReshapedToIntVector(
+      absl::string_view name, std::vector<int64>* out,
+      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
 
   // Converts a constant int32 or int64 Tensor into an xla int64 Literal.
-  Status ConstantInputAsInt64Literal(int index, xla::Literal* out);
-  Status ConstantInputAsInt64Literal(absl::string_view name, xla::Literal* out);
+  Status ConstantInputAsInt64Literal(
+      int index, xla::Literal* out,
+      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
+  Status ConstantInputAsInt64Literal(
+      absl::string_view name, xla::Literal* out,
+      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
 
   // Converts a constant 1D int32 or int64 tensor into a TensorShape.
-  Status ConstantInputAsShape(int index, TensorShape* shape);
+  Status ConstantInputAsShape(
+      int index, TensorShape* shape,
+      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
 
   // Converts a constant 1D int32 or int64 tensor, or a scalar with value -1
   // into a PartialTensorShape.
@@ -166,8 +191,9 @@
   // Returns the named list-valued immutable input in "list", as
   // defined in the OpDef.  If the named output is not list-valued,
   // returns a one-element list.
-  Status ConstantInputList(absl::string_view name,
-                           std::vector<xla::Literal>* literals);
+  Status ConstantInputList(
+      absl::string_view name, std::vector<xla::Literal>* outputs,
+      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
 
   // Returns an XlaExpression describing the value of 'index'.
   const XlaExpression& InputExpression(int index);
@@ -309,8 +335,10 @@
   // cannot be evaluated, e.g., because it depends on unbound parameters,
   // returns a non-Ok status. If InputShape(index).num_elements() !=
   // new_shape.num_elements(), returns an error status.
-  Status ConstantInputReshaped(int index, absl::Span<const int64> new_dims,
-                               xla::Literal* constant_literal);
+  Status ConstantInputReshaped(
+      int index, absl::Span<const int64> new_dims,
+      xla::Literal* constant_literal,
+      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
 
   OpKernelContext* const context_;
   bool dynamic_dimension_is_minus_one_;
diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD
index 0d6e283..eae1816 100644
--- a/tensorflow/compiler/xla/client/BUILD
+++ b/tensorflow/compiler/xla/client/BUILD
@@ -218,13 +218,17 @@
     visibility = ["//visibility:public"],
     deps = [
         ":xla_builder",
+        "//tensorflow/compiler/xla:comparison_util",
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/service:hlo_evaluator",
+        "//tensorflow/compiler/xla/service:hlo_proto_cc",
+        "//tensorflow/stream_executor/lib",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/types:span",
diff --git a/tensorflow/compiler/xla/client/value_inference.cc b/tensorflow/compiler/xla/client/value_inference.cc
index 41c2c92..7dd2395 100644
--- a/tensorflow/compiler/xla/client/value_inference.cc
+++ b/tensorflow/compiler/xla/client/value_inference.cc
@@ -13,15 +13,24 @@
 limitations under the License.
 ==============================================================================*/
 #include "tensorflow/compiler/xla/client/value_inference.h"
+#include <utility>
+#include <vector>
 
+#include "absl/container/flat_hash_map.h"
 #include "absl/types/span.h"
+#include "tensorflow/compiler/xla/comparison_util.h"
 #include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
 
 namespace xla {
 namespace {
@@ -67,15 +76,13 @@
   }
 }
 
-using GetOperand = std::function<StatusOr<LiteralSlice>(int64 operand_index,
-                                                        int64 opreand_handle)>;
+
 
 // HloProtoEvaluator evaluates an hlo proto and returns a literal. The user has
 // to provide operand as literals through the get_operand function.
 struct HloProtoEvaluator {
-  explicit HloProtoEvaluator(HloInstructionProto inst, GetOperand get_operand)
+  explicit HloProtoEvaluator(HloInstructionProto inst)
       : inst(std::move(inst)),
-        get_operand(get_operand),
         module("EmptyModuleForEvaluation", HloModuleConfig()) {}
 
   // WithOpCode changes the called computation of the instruction being
@@ -103,6 +110,12 @@
     return *this;
   }
 
+  // WithOperands sets the literals Evaluate() uses as the operands.
+  HloProtoEvaluator& WithOperands(absl::Span<Literal> operands) {
+    this->operands = operands;
+    return *this;
+  }
+
   StatusOr<Literal> Evaluate() {
     // Evaluate the instruction by swapping it's operands with constant
     // instructions with given literals.
@@ -110,9 +123,8 @@
     absl::flat_hash_map<int64, HloInstruction*> operand_map;
     for (int64 i = 0; i < inst.operand_ids_size(); ++i) {
       int64 operand_handle = inst.operand_ids(i);
-      TF_ASSIGN_OR_RETURN(auto literal, get_operand(i, inst.operand_ids(i)));
       std::unique_ptr<HloInstruction> operand =
-          HloInstruction::CreateConstant(literal.Clone());
+          HloInstruction::CreateConstant(operands[i].Clone());
       operand_map[operand_handle] = operand.get();
       builder.AddInstruction(std::move(operand));
     }
@@ -144,108 +156,448 @@
   }
 
   HloInstructionProto inst;
-  GetOperand get_operand;
+
   HloModule module;
+  absl::Span<Literal> operands;
   HloComputation* computation = nullptr;
   absl::optional<PrimitiveType> primitive_type = absl::nullopt;
   absl::optional<HloOpcode> opcode = absl::nullopt;
 };
+
+enum PostorderDFSNodeType {
+  // This node is about figuring out the constant value.
+  kConstantValue = 0,
+  // This node is about figuring out the constant bound.
+  kConstantUpperBound,
+  kConstantLowerBound,
+  // This node is about figuring out whether a value is dynamic.
+  kValueIsDynamic,
+  // This node is about figuring out whether a bound value is dynamic. It's
+  // similar to kValueIsDynamic, but views shape bound as static values.
+  kBoundIsDynamic,
+};
+
+// Each node in the postorder traversal tree may depend on traversing the
+// values of the node's children.
+struct PostorderDFSDep {
+  explicit PostorderDFSDep(int64 handle, PostorderDFSNodeType type)
+      : handle(handle), type(type) {}
+  int64 handle;
+  PostorderDFSNodeType type;
+};
+
+// This function represents the logic to visit a node once its dependencies
+// (operands) are all resolved.
+using Visit = std::function<StatusOr<Literal>(absl::Span<Literal>)>;
+// Convenient specializations of Visit function for different operands.
+using Visit0D = std::function<StatusOr<Literal>()>;
+using Visit1D = std::function<StatusOr<Literal>(Literal)>;
+using Visit2D = std::function<StatusOr<Literal>(Literal, Literal)>;
+
+// A postorder dfs node can be visited once its dependency requests are all
+// fulfilled.
+struct PostorderDFSNode {
+  PostorderDFSNode& AddDependency(int64 handle, PostorderDFSNodeType type) {
+    dependencies.emplace_back(handle, type);
+    return *this;
+  }
+
+  PostorderDFSNode& AddVisit(const Visit& visit) {
+    this->visit = visit;
+    return *this;
+  }
+
+  PostorderDFSNode& AddVisit(const Visit0D& visit) {
+    this->visit = [visit](absl::Span<Literal> literals) { return visit(); };
+    return *this;
+  }
+
+  PostorderDFSNode& AddVisit(const Visit1D& visit) {
+    this->visit = [visit](absl::Span<Literal> literals) {
+      return visit(std::move(literals[0]));
+    };
+    return *this;
+  }
+
+  PostorderDFSNode& AddVisit(const Visit2D& visit) {
+    this->visit = [visit](absl::Span<Literal> literals) {
+      return visit(std::move(literals[0]), std::move(literals[1]));
+    };
+    return *this;
+  }
+
+  std::vector<PostorderDFSDep> dependencies;
+  Visit visit;
+};
+
+// Convert an integer handle to HloInstructionProto.
+using HandleToInstruction = std::function<const HloInstructionProto*(int64)>;
+using HandleToComputation = std::function<const HloComputationProto*(int64)>;
+
+struct PostorderDFSVisitor {
+  PostorderDFSVisitor(HandleToInstruction handle_to_instruction,
+                        HandleToComputation handle_to_computation)
+      : handle_to_instruction(handle_to_instruction),
+        handle_to_computation(handle_to_computation) {}
+
+  StatusOr<PostorderDFSNode> AnalyzeUpperBound(int64 handle);
+  StatusOr<PostorderDFSNode> AnalyzeLowerBound(int64 handle);
+  StatusOr<PostorderDFSNode> AnalyzeIsDynamic(int64 handle,
+                                              PostorderDFSNodeType type);
+  StatusOr<PostorderDFSNode> AnalyzeConstant(int64 handle);
+  StatusOr<PostorderDFSNode> AnalyzeConstantValueFallback(int64 handle,
+                                                  PostorderDFSNodeType type);
+
+  StatusOr<Literal> PostOrderDFSVisit(int64 handle, PostorderDFSNodeType type);
+
+  // Returns true if a value represented by `handle` is an integral type or
+  // just got converted from an integral type to floating point type.
+  bool IsValueEffectiveInteger(int64 handle) {
+    const HloInstructionProto* instr = handle_to_instruction(handle);
+    if (primitive_util::IsIntegralType(instr->shape().element_type())) {
+      return true;
+    }
+    // Also returns true if this is a convert that converts an integer to float.
+    HloOpcode opcode = StringToHloOpcode(instr->opcode()).ValueOrDie();
+    if (opcode != HloOpcode::kConvert) {
+      return false;
+    }
+    const HloInstructionProto* parent =
+        handle_to_instruction(instr->operand_ids(0));
+    if (primitive_util::IsIntegralType(parent->shape().element_type())) {
+      return true;
+    }
+    return false;
+  }
+
+  absl::flat_hash_map<std::pair<int64, PostorderDFSNodeType>, Literal>
+      evaluated;
+  HandleToInstruction handle_to_instruction;
+  HandleToComputation handle_to_computation;
+};
+
 }  // namespace
 
-StatusOr<Literal> ValueInference::AnalyzeConstantLiteral(int64 handle) {
-  TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
-                      builder_->LookUpInstructionByHandle(handle));
+// Analyze a tensor's constant value, upper-bound value or lower-bound value.
+StatusOr<PostorderDFSNode> PostorderDFSVisitor::AnalyzeConstantValueFallback(
+    int64 handle, PostorderDFSNodeType type) {
+  const HloInstructionProto* root = handle_to_instruction(handle);
   TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
+  PostorderDFSNode result;
+  for (auto operand_id : root->operand_ids()) {
+    result.AddDependency(operand_id, type);
+  }
   switch (opcode) {
-    case HloOpcode::kGetDimensionSize: {
-      int64 dimension = root->dimensions(0);
-      int64 operand_handle = root->operand_ids(0);
-      TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
-                          builder_->LookUpInstructionByHandle(operand_handle));
-      if (operand_proto->shape().is_dynamic_dimension(dimension)) {
-        // The value is dynamic. We return a garbage literal here, which
-        // will be masked out later.
-        return CreateGarbageLiteral(Shape(root->shape()));
-      } else {
-        return LiteralUtil::CreateR0<int32>(
-            operand_proto->shape().dimensions(dimension));
-      }
-    }
       // Non functional ops.
     case HloOpcode::kRng:
     case HloOpcode::kAllReduce:
-      // TODO(b/33009255): Implement constant folding for cross replica sum.
     case HloOpcode::kInfeed:
     case HloOpcode::kOutfeed:
     case HloOpcode::kCall:
-      // TODO(b/32495713): We aren't checking the to_apply computation itself,
-      // so we conservatively say that computations containing the Call op
-      // cannot be constant.  We cannot set is_functional=false in other similar
-      // cases since we're already relying on IsConstant to return true.
     case HloOpcode::kCustomCall:
     case HloOpcode::kWhile:
     case HloOpcode::kConditional:
-      // TODO(b/32495713): We aren't checking the condition and body
-      // computations themselves.
     case HloOpcode::kSend:
     case HloOpcode::kRecv:
     case HloOpcode::kParameter: {
-      // The value is dynamic. We return a garbage literal here, which
-      // will be masked out later.
-      return CreateGarbageLiteral(Shape(root->shape()));
+      return result.AddVisit([root](absl::Span<Literal>) {
+        // The value is dynamic. We return a garbage literal here, which
+        // will be masked out later.
+        return CreateGarbageLiteral(Shape(root->shape()));
+      });
+    }
+    // Subtract and Divide use lower-bound as second operand.
+    case HloOpcode::kSubtract:
+    case HloOpcode::kCos:
+    case HloOpcode::kSin:
+    case HloOpcode::kNegate:
+    case HloOpcode::kAbs:
+    case HloOpcode::kDivide:
+    case HloOpcode::kGetDimensionSize: {
+      return InvalidArgument("AnalyzeConstantValue can't handle opcode: %s",
+                             root->opcode());
     }
     case HloOpcode::kGetTupleElement: {
       int64 operand_handle = root->operand_ids(0);
-      TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
-                          builder_->LookUpInstructionByHandle(operand_handle));
+      const HloInstructionProto* operand_proto =
+          handle_to_instruction(operand_handle);
       TF_ASSIGN_OR_RETURN(HloOpcode operand_opcode,
                           StringToHloOpcode(operand_proto->opcode()));
       if (operand_opcode == HloOpcode::kParameter) {
-        // Don't materialize the whole parameter if it's followed by a GTE.
-        return CreateGarbageLiteral(Shape(root->shape()));
+        return PostorderDFSNode().AddVisit([root](absl::Span<Literal>) {
+          // The value is dynamic. We return a garbage literal here, which
+          // will be masked out later.
+          return CreateGarbageLiteral(Shape(root->shape()));
+        });
       }
-      return HloProtoEvaluator(*root,
-                               [&](int64 operand_index, int64 operand_handle) {
-                                 return AnalyzeConstant(operand_handle);
-                               })
-          .WithPrimitiveType(PRED)
-          .Evaluate();
+
+      return result.AddVisit([root](absl::Span<Literal> operands) {
+        return HloProtoEvaluator(*root)
+            .WithOperands(operands)
+            .Evaluate();
+      });
     }
     case HloOpcode::kReduce:
     case HloOpcode::kScatter:
     case HloOpcode::kReduceWindow: {
-      HloComputationProto computation_proto =
-          builder_->embedded_[root->called_computation_ids(0)];
-      TF_ASSIGN_OR_RETURN(auto computation, HloComputation::CreateFromProto(
-                                                computation_proto, {}));
-      return HloProtoEvaluator(*root,
-                               [&](int64 operand_index, int64 operand_handle) {
-                                 return AnalyzeConstant(operand_handle);
-                               })
-          .WithComputation(std::move(computation))
-          .Evaluate();
+      const HloComputationProto* computation_proto =
+          handle_to_computation(root->called_computation_ids(0));
+      return result.AddVisit(
+          [root, computation_proto](
+              absl::Span<Literal> operands) -> StatusOr<Literal> {
+            TF_ASSIGN_OR_RETURN(
+                auto computation,
+                HloComputation::CreateFromProto(*computation_proto, {}));
+            return HloProtoEvaluator(*root)
+                .WithOperands(operands)
+                .WithComputation(std::move(computation))
+                .Evaluate();
+          });
     }
-    default:
-      return HloProtoEvaluator(*root,
-                               [&](int64 operand_index, int64 operand_handle) {
-                                 return AnalyzeConstant(operand_handle);
-                               })
-          .Evaluate();
+    default: {
+      return result.AddVisit([root](absl::Span<Literal> operands) {
+        return HloProtoEvaluator(*root).WithOperands(operands).Evaluate();
+      });
+    }
   }
 }
 
-StatusOr<Literal> ValueInference::AnalyzeIsDynamicLiteral(int64 handle) {
-  TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
-                      builder_->LookUpInstructionByHandle(handle));
+StatusOr<PostorderDFSNode> PostorderDFSVisitor::AnalyzeUpperBound(
+    int64 handle) {
+  const HloInstructionProto* root = handle_to_instruction(handle);
   TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
   switch (opcode) {
     case HloOpcode::kGetDimensionSize: {
       int64 dimension = root->dimensions(0);
       int64 operand_handle = root->operand_ids(0);
-      TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
-                          builder_->LookUpInstructionByHandle(operand_handle));
-      return LiteralUtil::CreateR0<bool>(
-          operand_proto->shape().is_dynamic_dimension(dimension));
+      const HloInstructionProto* operand_proto =
+          handle_to_instruction(operand_handle);
+      return PostorderDFSNode().AddVisit(
+          [operand_proto, dimension]() -> StatusOr<Literal> {
+            return LiteralUtil::CreateR0<int32>(
+                operand_proto->shape().dimensions(dimension));
+          });
+    }
+    case HloOpcode::kAbs: {
+      // upper-bound(abs(operand)) = max(abs(lower-bound(operand)),
+      //                                 abs(upper-bound(operand)))
+      return PostorderDFSNode()
+          .AddDependency(root->operand_ids(0),
+                         PostorderDFSNodeType::kConstantLowerBound)
+          .AddDependency(root->operand_ids(0),
+                         PostorderDFSNodeType::kConstantUpperBound)
+          .AddVisit([](Literal lower_bound,
+                       Literal upper_bound) -> StatusOr<Literal> {
+            HloEvaluator evaluator;
+            TF_ASSIGN_OR_RETURN(auto lower_bound_abs,
+                                evaluator.EvaluateElementwiseUnaryOp(
+                                    HloOpcode::kAbs, lower_bound));
+            TF_ASSIGN_OR_RETURN(auto upper_bound_abs,
+                                evaluator.EvaluateElementwiseUnaryOp(
+                                    HloOpcode::kAbs, upper_bound));
+            return evaluator.EvaluateElementwiseBinaryOp(
+                HloOpcode::kMaximum, lower_bound_abs, upper_bound_abs);
+          });
+    }
+    case HloOpcode::kNegate: {
+      // upper-bound(negate(operand)) = negate(lower-bound(operand))
+      return PostorderDFSNode()
+          .AddDependency(root->operand_ids(0),
+                         PostorderDFSNodeType::kConstantLowerBound)
+          .AddVisit([](Literal lower_bound) -> StatusOr<Literal> {
+            HloEvaluator evaluator;
+            return evaluator.EvaluateElementwiseUnaryOp(HloOpcode::kNegate,
+                                                        lower_bound);
+          });
+    }
+    case HloOpcode::kSubtract:
+    case HloOpcode::kDivide: {
+      // Lower-bound is used for second operand of subtract and divide.
+      return PostorderDFSNode()
+          .AddDependency(root->operand_ids(0),
+                         PostorderDFSNodeType::kConstantUpperBound)
+          .AddDependency(root->operand_ids(1),
+                         PostorderDFSNodeType::kConstantLowerBound)
+          .AddVisit(
+              [root, opcode, this](Literal upper_bound,
+                                   Literal lower_bound) -> StatusOr<Literal> {
+                if (opcode == HloOpcode::kDivide &&
+                    this->IsValueEffectiveInteger(root->operand_ids(1))) {
+                  // Because in many cases the lower bound of a value is
+                  // integer 0, instead of throwing a divide-by-zero error
+                  // at compile time, we adjust the bound and defer the check
+                  // to runtime. In those cases we use the upper-bound of the
+                  // first operand as a placeholder.
+                  HloEvaluator evaluator;
+                  auto zero =
+                      LiteralUtil::Zero(lower_bound.shape().element_type());
+                  zero = zero.Broadcast(lower_bound.shape(), {}).ValueOrDie();
+                  TF_ASSIGN_OR_RETURN(
+                      auto lower_bound_is_zero,
+                      evaluator.EvaluateElementwiseCompareOp(
+                          ComparisonDirection::kEq, lower_bound, zero));
+
+                  auto one =
+                      LiteralUtil::One(lower_bound.shape().element_type());
+                  one = one.Broadcast(lower_bound.shape(), {}).ValueOrDie();
+                  TF_ASSIGN_OR_RETURN(
+                      lower_bound, evaluator.EvaluateElementwiseTernaryOp(
+                                       HloOpcode::kSelect, lower_bound_is_zero,
+                                       one, lower_bound));
+                }
+                std::vector<Literal> new_operands;
+                new_operands.emplace_back(std::move(upper_bound));
+                new_operands.emplace_back(std::move(lower_bound));
+                return HloProtoEvaluator(*root)
+                    .WithOperands(absl::MakeSpan(new_operands))
+                    .Evaluate();
+              });
+    }
+    default:
+      return AnalyzeConstantValueFallback(
+          handle, PostorderDFSNodeType::kConstantUpperBound);
+  }
+}
+
+StatusOr<PostorderDFSNode> PostorderDFSVisitor::AnalyzeLowerBound(
+    int64 handle) {
+  const HloInstructionProto* root = handle_to_instruction(handle);
+  TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
+  switch (opcode) {
+    case HloOpcode::kGetDimensionSize: {
+      int64 dimension = root->dimensions(0);
+      int64 operand_handle = root->operand_ids(0);
+      const HloInstructionProto* operand_proto =
+          handle_to_instruction(operand_handle);
+      return PostorderDFSNode().AddVisit(
+          [dimension, operand_proto]() -> StatusOr<Literal> {
+            if (operand_proto->shape().is_dynamic_dimension(dimension)) {
+              return LiteralUtil::CreateR0<int32>(0);
+            } else {
+              return LiteralUtil::CreateR0<int32>(
+                  operand_proto->shape().dimensions(dimension));
+            }
+          });
+    }
+    case HloOpcode::kAbs: {
+      // lower-bound(abs(operand)) = min(abs(lower-bound(operand)),
+      // abs(upper-bound(operand))). NOTE: over-estimates if the range spans 0.
+      return PostorderDFSNode()
+          .AddDependency(root->operand_ids(0),
+                         PostorderDFSNodeType::kConstantLowerBound)
+          .AddDependency(root->operand_ids(0),
+                         PostorderDFSNodeType::kConstantUpperBound)
+          .AddVisit([](Literal lower_bound,
+                       Literal upper_bound) -> StatusOr<Literal> {
+            HloEvaluator evaluator;
+            TF_ASSIGN_OR_RETURN(auto lower_bound_abs,
+                                evaluator.EvaluateElementwiseUnaryOp(
+                                    HloOpcode::kAbs, lower_bound));
+            TF_ASSIGN_OR_RETURN(auto upper_bound_abs,
+                                evaluator.EvaluateElementwiseUnaryOp(
+                                    HloOpcode::kAbs, upper_bound));
+            return evaluator.EvaluateElementwiseBinaryOp(
+                HloOpcode::kMinimum, lower_bound_abs, upper_bound_abs);
+          });
+    }
+    case HloOpcode::kNegate: {
+      // lower-bound(negate(operand)) = negate(upper-bound(operand))
+      return PostorderDFSNode()
+          .AddDependency(root->operand_ids(0),
+                         PostorderDFSNodeType::kConstantUpperBound)
+          .AddVisit([](Literal upper_bound) -> StatusOr<Literal> {
+            HloEvaluator evaluator;
+            return evaluator.EvaluateElementwiseUnaryOp(HloOpcode::kNegate,
+                                                        upper_bound);
+          });
+    }
+    case HloOpcode::kSubtract:
+    case HloOpcode::kDivide: {
+      // Upper bound is used for second operand of subtract and divide.
+      return PostorderDFSNode()
+          .AddDependency(root->operand_ids(0),
+                         PostorderDFSNodeType::kConstantLowerBound)
+          .AddDependency(root->operand_ids(1),
+                         PostorderDFSNodeType::kConstantUpperBound)
+          .AddVisit([root](absl::Span<Literal> operands) -> StatusOr<Literal> {
+            return HloProtoEvaluator(*root).WithOperands(operands).Evaluate();
+          });
+    }
+    default:
+      return AnalyzeConstantValueFallback(
+          handle, PostorderDFSNodeType::kConstantLowerBound);
+  }
+}
+
+StatusOr<PostorderDFSNode> PostorderDFSVisitor::AnalyzeConstant(
+    int64 handle) {
+  const HloInstructionProto* root = handle_to_instruction(handle);
+  TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
+  switch (opcode) {
+    case HloOpcode::kGetDimensionSize: {
+      int64 dimension = root->dimensions(0);
+      int64 operand_handle = root->operand_ids(0);
+      const HloInstructionProto* operand_proto =
+          handle_to_instruction(operand_handle);
+      return PostorderDFSNode().AddVisit(
+          [operand_proto, dimension, root]() -> StatusOr<Literal> {
+            if (operand_proto->shape().is_dynamic_dimension(dimension)) {
+              // The value is dynamic, we return garbage data here and mask them
+              // out later.
+              return CreateGarbageLiteral(Shape(root->shape()));
+            } else {
+              return LiteralUtil::CreateR0<int32>(
+                  operand_proto->shape().dimensions(dimension));
+            }
+          });
+    }
+    case HloOpcode::kSubtract:
+    case HloOpcode::kCos:
+    case HloOpcode::kSin:
+    case HloOpcode::kNegate:
+    case HloOpcode::kAbs:
+    case HloOpcode::kDivide: {
+      PostorderDFSNode result;
+      for (auto operand_id : root->operand_ids()) {
+        result.AddDependency(operand_id, PostorderDFSNodeType::kConstantValue);
+      }
+      return result.AddVisit(
+          [root](absl::Span<Literal> operands) -> StatusOr<Literal> {
+            return HloProtoEvaluator(*root).WithOperands(operands).Evaluate();
+          });
+    }
+    default:
+      return AnalyzeConstantValueFallback(handle,
+                                          PostorderDFSNodeType::kConstantValue);
+  }
+}
+
+StatusOr<PostorderDFSNode> PostorderDFSVisitor::AnalyzeIsDynamic(
+    int64 handle, PostorderDFSNodeType type) {
+  const HloInstructionProto* root = handle_to_instruction(handle);
+  // Invariant check.
+  TF_RET_CHECK(root);
+  TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
+  PostorderDFSNode result;
+  for (auto operand_id : root->operand_ids()) {
+    result.AddDependency(operand_id, type);
+  }
+  switch (opcode) {
+    case HloOpcode::kGetDimensionSize: {
+      int64 dimension = root->dimensions(0);
+      int64 operand_handle = root->operand_ids(0);
+      const HloInstructionProto* operand_proto =
+          handle_to_instruction(operand_handle);
+      return PostorderDFSNode().AddVisit([operand_proto, dimension,
+                                          type]() -> StatusOr<Literal> {
+        if (type == PostorderDFSNodeType::kBoundIsDynamic) {
+          // The bound of dynamic dimension is not dynamic.
+          return LiteralUtil::CreateR0<bool>(false);
+        }
+        // The value of dynamic dimension is dynamic.
+        return LiteralUtil::CreateR0<bool>(
+            operand_proto->shape().is_dynamic_dimension(dimension));
+      });
     }
     case HloOpcode::kAbs:
     case HloOpcode::kRoundNearestAfz:
@@ -274,9 +626,7 @@
     case HloOpcode::kCbrt:
     case HloOpcode::kTanh: {
       // Forward operand as they don't change if a value is dynamic or static.
-      int64 operand_handle = root->operand_ids(0);
-      TF_ASSIGN_OR_RETURN(auto literal, AnalyzeIsDynamic(operand_handle));
-      return literal.Clone();
+      return result.AddVisit([](Literal operand) { return operand; });
     }
     case HloOpcode::kAdd:
     case HloOpcode::kAtan2:
@@ -295,13 +645,13 @@
     case HloOpcode::kShiftLeft:
     case HloOpcode::kShiftRightArithmetic:
     case HloOpcode::kShiftRightLogical: {
-      return HloProtoEvaluator(*root,
-                               [&](int64 operand_index, int64 operand_handle) {
-                                 return AnalyzeIsDynamic(operand_handle);
-                               })
-          .WithPrimitiveType(PRED)
-          .WithOpCode(HloOpcode::kOr)
-          .Evaluate();
+      return result.AddVisit([root](absl::Span<Literal> operands) {
+        return HloProtoEvaluator(*root)
+            .WithOperands(operands)
+            .WithPrimitiveType(PRED)
+            .WithOpCode(HloOpcode::kOr)
+            .Evaluate();
+      });
     }
     case HloOpcode::kTuple:
     case HloOpcode::kTranspose:
@@ -310,112 +660,135 @@
     case HloOpcode::kConcatenate:
     case HloOpcode::kReshape:
     case HloOpcode::kPad: {
-      return HloProtoEvaluator(*root,
-                               [&](int64 operand_index, int64 operand_handle) {
-                                 return AnalyzeIsDynamic(operand_handle);
-                               })
-          .WithPrimitiveType(PRED)
-          .Evaluate();
+      return result.AddVisit([root](absl::Span<Literal> operands) {
+        return HloProtoEvaluator(*root)
+            .WithOperands(operands)
+            .WithPrimitiveType(PRED)
+            .Evaluate();
+      });
     }
     case HloOpcode::kGetTupleElement: {
       int64 operand_handle = root->operand_ids(0);
-      TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
-                          builder_->LookUpInstructionByHandle(operand_handle));
+      const HloInstructionProto* operand_proto =
+          handle_to_instruction(operand_handle);
       TF_ASSIGN_OR_RETURN(HloOpcode operand_opcode,
                           StringToHloOpcode(operand_proto->opcode()));
       if (operand_opcode == HloOpcode::kParameter) {
-        // Don't materialize the whole parameter if it's followed by a GTE.
-        return CreatePredLiteral(true, Shape(root->shape()));
+        return PostorderDFSNode().AddVisit([root]() -> StatusOr<Literal> {
+          // Don't materialize the whole parameter if it's followed by a GTE.
+          return CreatePredLiteral(true, Shape(root->shape()));
+        });
       }
-      return HloProtoEvaluator(*root,
-                               [&](int64 operand_index, int64 operand_handle) {
-                                 return AnalyzeIsDynamic(operand_handle);
-                               })
-          .WithPrimitiveType(PRED)
-          .Evaluate();
+      return result.AddVisit([root](absl::Span<Literal> operands) {
+        return HloProtoEvaluator(*root)
+            .WithOperands(operands)
+            .WithPrimitiveType(PRED)
+            .Evaluate();
+      });
     }
 
     case HloOpcode::kReduce: {
-      std::vector<std::unique_ptr<HloInstruction>> operand_storage;
-      absl::flat_hash_map<int64, HloInstruction*> operand_map;
-      absl::flat_hash_map<int64, HloComputation*> computation_map;
-
-      Shape scalar_shape = ShapeUtil::MakeScalarShape(xla::PRED);
-      HloComputation::Builder b("reduce_or");
-      auto lhs = b.AddInstruction(
-          HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
-      auto rhs = b.AddInstruction(
-          HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
-      b.AddInstruction(
-          HloInstruction::CreateBinary(scalar_shape, HloOpcode::kOr, lhs, rhs));
-      auto reduce_computation = b.Build();
-      return HloProtoEvaluator(*root,
-                               [&](int64 operand_index, int64 operand_handle) {
-                                 return AnalyzeIsDynamic(operand_handle);
-                               })
-          .WithPrimitiveType(PRED)
-          .WithComputation(std::move(reduce_computation))
-          .Evaluate();
+      return result.AddVisit([root](absl::Span<Literal> operands) {
+        Shape scalar_shape = ShapeUtil::MakeScalarShape(xla::PRED);
+        HloComputation::Builder b("reduce_or");
+        auto lhs = b.AddInstruction(
+            HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
+        auto rhs = b.AddInstruction(
+            HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
+        b.AddInstruction(HloInstruction::CreateBinary(
+            scalar_shape, HloOpcode::kOr, lhs, rhs));
+        auto reduce_computation = b.Build();
+        return HloProtoEvaluator(*root)
+            .WithOperands(operands)
+            .WithPrimitiveType(PRED)
+            .WithComputation(std::move(reduce_computation))
+            .Evaluate();
+      });
     }
     case HloOpcode::kConstant:
     case HloOpcode::kIota: {
-      return CreatePredLiteral(false, Shape(root->shape()));
+      return result.AddVisit(
+          [root]() { return CreatePredLiteral(false, Shape(root->shape())); });
     }
     case HloOpcode::kParameter: {
-      return CreatePredLiteral(true, Shape(root->shape()));
+      return result.AddVisit(
+          [root]() { return CreatePredLiteral(true, Shape(root->shape())); });
     }
     case HloOpcode::kSelect: {
-      TF_ASSIGN_OR_RETURN(OptionaLiteralSlice optional_selector_literal,
-                          AnalyzeOptionalConstant(root->operand_ids(0)));
-      TF_ASSIGN_OR_RETURN(LiteralSlice lhs,
-                          AnalyzeIsDynamic(root->operand_ids(1)));
-      TF_ASSIGN_OR_RETURN(LiteralSlice rhs,
-                          AnalyzeIsDynamic(root->operand_ids(2)));
+      return PostorderDFSNode()
+          .AddDependency(root->operand_ids(0),
+                         PostorderDFSNodeType::kConstantValue)
+          .AddDependency(root->operand_ids(0),
+                         PostorderDFSNodeType::kValueIsDynamic)
+          // lhs dependency.
+          .AddDependency(root->operand_ids(1), type)
+          // rhs dependency.
+          .AddDependency(root->operand_ids(2), type)
+          .AddVisit([root](absl::Span<Literal> operands) -> StatusOr<Literal> {
+            OptionalLiteral optional_selector_literal(std::move(operands[0]),
+                                                      std::move(operands[1]));
+            Literal lhs = std::move(operands[2]);
+            Literal rhs = std::move(operands[3]);
+            auto result = CreatePredLiteral(true, Shape(root->shape()));
 
-      auto result = CreatePredLiteral(true, Shape(root->shape()));
+            result.MutableEachCell<bool>(
+                [&](absl::Span<const int64> indices, bool value) {
+                  absl::optional<bool> optional_selector =
+                      optional_selector_literal.Get<bool>(indices);
 
-      result.MutableEachCell<bool>(
-          [&](absl::Span<const int64> indices, bool value) {
-            absl::optional<bool> optional_selector =
-                optional_selector_literal.Get<bool>(indices);
-
-            bool lhs_value = lhs.Get<bool>(indices);
-            bool rhs_value = rhs.Get<bool>(indices);
-            if (optional_selector.has_value()) {
-              // Manually evaluate the selection without using Evaluator.
-              if (*optional_selector) {
-                return lhs_value;
-              } else {
-                return rhs_value;
-              }
-            } else {
-              // Conservatively assume value is dynamic if selector is dynamic.
-              return true;
-            }
+                  bool lhs_value = lhs.Get<bool>(indices);
+                  bool rhs_value = rhs.Get<bool>(indices);
+                  if (optional_selector.has_value()) {
+                    // Manually evaluate the selection without using Evaluator.
+                    if (*optional_selector) {
+                      return lhs_value;
+                    } else {
+                      return rhs_value;
+                    }
+                  } else {
+                    // Conservatively assume value is dynamic if selector is
+                    // dynamic.
+                    return true;
+                  }
+                });
+            return result;
           });
-      return result;
     }
     case HloOpcode::kGather: {
-      TF_ASSIGN_OR_RETURN(OptionaLiteralSlice optional_selector_literal,
-                          AnalyzeOptionalConstant(root->operand_ids(1)));
-      if (!optional_selector_literal.AllValid()) {
-        // Conservatively assume result are dynamic.
-        return CreatePredLiteral(true, Shape(root->shape()));
-      }
-      return HloProtoEvaluator(*root,
-                               [&](int64 operand_index, int64 operand_handle) {
-                                 if (operand_index == 1) {
-                                   return AnalyzeConstant(operand_handle);
-                                 } else {
-                                   return AnalyzeIsDynamic(operand_handle);
-                                 }
-                               })
-          .WithPrimitiveType(PRED)
-          .Evaluate();
+      return PostorderDFSNode()
+          .AddDependency(root->operand_ids(0), type)
+          .AddDependency(root->operand_ids(1),
+                         PostorderDFSNodeType::kConstantValue)
+          .AddDependency(root->operand_ids(1),
+                         PostorderDFSNodeType::kValueIsDynamic)
+          .AddVisit([root](absl::Span<Literal> operands) -> StatusOr<Literal> {
+            OptionalLiteral optional_selector_literal(std::move(operands[1]),
+                                                      std::move(operands[2]));
+
+            if (!optional_selector_literal.AllValid()) {
+              // Conservatively assume results are dynamic.
+              return CreatePredLiteral(true, Shape(root->shape()));
+            }
+            std::vector<Literal> new_operands;
+            new_operands.emplace_back(std::move(operands[0]));
+            new_operands.emplace_back(
+                optional_selector_literal.GetValue()->Clone());
+
+            return HloProtoEvaluator(*root)
+                .WithOperands(absl::MakeSpan(new_operands))
+                .WithPrimitiveType(PRED)
+                .Evaluate();
+          });
     }
     case HloOpcode::kCustomCall: {
       if (root->custom_call_target() == "SetBound") {
-        return CreatePredLiteral(true, Shape(root->shape()));
+        return PostorderDFSNode().AddVisit([type, root]() -> StatusOr<Literal> {
+          if (type == PostorderDFSNodeType::kBoundIsDynamic) {
+            return CreatePredLiteral(false, Shape(root->shape()));
+          } else {
+            return CreatePredLiteral(true, Shape(root->shape()));
+          }
+        });
       } else {
         return InvalidArgument(
             "Dynamic inferencing on custom call %s is not supported",
@@ -430,29 +803,138 @@
   }
 }
 
-StatusOr<LiteralSlice> ValueInference::AnalyzeIsDynamic(int64 handle) {
-  if (is_dynamic_.contains(handle)) {
-    return LiteralSlice(is_dynamic_[handle]);
+StatusOr<Literal> PostorderDFSVisitor::PostOrderDFSVisit(
+    int64 handle, PostorderDFSNodeType type) {
+  enum VisitState {
+    kUnvisited = 0,
+    kVisiting,
+    kVisited,
+  };
+
+  struct WorkItem {
+    explicit WorkItem(int64 handle, PostorderDFSNodeType type, VisitState state)
+        : handle(handle), type(type), state(state) {}
+    int64 handle;
+    PostorderDFSNodeType type;
+    VisitState state;
+    PostorderDFSNode node;
+  };
+
+  std::vector<WorkItem> stack;
+  stack.push_back(WorkItem(handle, type, kUnvisited));
+  while (!stack.empty()) {
+    WorkItem& item = stack.back();
+    VLOG(1) << "stack top" << handle_to_instruction(item.handle)->DebugString();
+    if (item.state == kVisiting) {
+      VLOG(1) << "visiting";
+      // The operands are ready, visit the node itself.
+
+      // Gather dependencies.
+      std::vector<Literal> literals;
+      for (const PostorderDFSDep& dep : item.node.dependencies) {
+        std::pair<int64, PostorderDFSNodeType> key(dep.handle, dep.type);
+        TF_RET_CHECK(evaluated.contains(key));
+        literals.emplace_back(evaluated.at(key).Clone());
+      }
+      VLOG(1) << "start visiting";
+      TF_ASSIGN_OR_RETURN(auto literal,
+                          item.node.visit(absl::MakeSpan(literals)));
+      VLOG(1) << "end visiting: " << literal.ToString();
+      std::pair<int64, PostorderDFSNodeType> key(item.handle, item.type);
+      evaluated[key] = std::move(literal);
+      stack.pop_back();
+      continue;
+    }
+    // This is the first time we see this node, we want to gather its
+    // dependencies.
+    VLOG(1) << "unvisited";
+    item.state = kVisiting;
+    PostorderDFSNode node;
+    switch (item.type) {
+      case PostorderDFSNodeType::kConstantValue: {
+      TF_ASSIGN_OR_RETURN(node, AnalyzeConstant(item.handle));
+        break;
+      }
+      case PostorderDFSNodeType::kConstantLowerBound: {
+      TF_ASSIGN_OR_RETURN(node, AnalyzeLowerBound(item.handle));
+        break;
+      }
+      case PostorderDFSNodeType::kConstantUpperBound: {
+        TF_ASSIGN_OR_RETURN(node, AnalyzeUpperBound(item.handle));
+        break;
+      }
+      case PostorderDFSNodeType::kBoundIsDynamic:
+      case PostorderDFSNodeType::kValueIsDynamic: {
+        TF_ASSIGN_OR_RETURN(node, AnalyzeIsDynamic(item.handle, item.type));
+        break;
+      }
+    }
+    // Store the node which is needed when its dependencies are resolved.
+    item.node = node;
+    // Enqueue dependencies into the stack.
+    for (const PostorderDFSDep& dep : node.dependencies) {
+      VLOG(1) << "dep" << handle_to_instruction(dep.handle)->DebugString();
+      stack.emplace_back(dep.handle, dep.type, kUnvisited);
+    }
   }
-  TF_ASSIGN_OR_RETURN(Literal literal, AnalyzeIsDynamicLiteral(handle));
-  is_dynamic_[handle] = std::move(literal);
-  return LiteralSlice(is_dynamic_[handle]);
+  VLOG(1) << "done" << evaluated[std::make_pair(handle, type)].ToString();
+  return evaluated[std::make_pair(handle, type)].Clone();
 }
 
-StatusOr<LiteralSlice> ValueInference::AnalyzeConstant(int64 handle) {
-  if (constant_.contains(handle)) {
-    return LiteralSlice(constant_[handle]);
-  }
-  TF_ASSIGN_OR_RETURN(Literal literal, AnalyzeConstantLiteral(handle));
-  constant_[handle] = std::move(literal);
-  return LiteralSlice(constant_[handle]);
+StatusOr<Literal> ValueInference::AnalyzeIsDynamic(XlaOp op) {
+  PostorderDFSVisitor visitor(
+      [&](int64 handle) {
+        return builder_->LookUpInstructionByHandle(handle).ValueOrDie();
+      },
+      [&](int64 handle) { return &(builder_->embedded_[handle]); });
+  return visitor.PostOrderDFSVisit(op.handle(),
+                                   PostorderDFSNodeType::kValueIsDynamic);
 }
 
-StatusOr<OptionaLiteralSlice> ValueInference::AnalyzeOptionalConstant(
-    int64 handle) {
-  TF_ASSIGN_OR_RETURN(LiteralSlice value, AnalyzeConstant(handle));
-  TF_ASSIGN_OR_RETURN(LiteralSlice mask, AnalyzeIsDynamic(handle));
-  return OptionaLiteralSlice(value, mask);
+StatusOr<OptionalLiteral> ValueInference::AnalyzeConstant(
+    XlaOp op, ValueInferenceMode mode) {
+  PostorderDFSVisitor visitor(
+      [&](int64 handle) {
+        return builder_->LookUpInstructionByHandle(handle).ValueOrDie();
+      },
+      [&](int64 handle) { return &(builder_->embedded_[handle]); });
+  switch (mode) {
+    case ValueInferenceMode::kLowerBound: {
+      TF_ASSIGN_OR_RETURN(
+          Literal value,
+          visitor.PostOrderDFSVisit(op.handle(),
+                                    PostorderDFSNodeType::kConstantLowerBound));
+      TF_ASSIGN_OR_RETURN(
+          Literal mask,
+          visitor.PostOrderDFSVisit(op.handle(),
+                                    PostorderDFSNodeType::kBoundIsDynamic));
+      return OptionalLiteral(std::move(value), std::move(mask));
+    }
+
+    case ValueInferenceMode::kUpperBound: {
+      TF_ASSIGN_OR_RETURN(
+          Literal value,
+          visitor.PostOrderDFSVisit(op.handle(),
+                                    PostorderDFSNodeType::kConstantUpperBound));
+      TF_ASSIGN_OR_RETURN(
+          Literal mask,
+          visitor.PostOrderDFSVisit(op.handle(),
+                                    PostorderDFSNodeType::kBoundIsDynamic));
+
+      return OptionalLiteral(std::move(value), std::move(mask));
+    }
+    case ValueInferenceMode::kValue: {
+      TF_ASSIGN_OR_RETURN(
+          Literal value,
+          visitor.PostOrderDFSVisit(op.handle(),
+                                    PostorderDFSNodeType::kConstantValue));
+      TF_ASSIGN_OR_RETURN(
+          Literal mask,
+          visitor.PostOrderDFSVisit(op.handle(),
+                                    PostorderDFSNodeType::kValueIsDynamic));
+      return OptionalLiteral(std::move(value), std::move(mask));
+    }
+  }
 }
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/client/value_inference.h b/tensorflow/compiler/xla/client/value_inference.h
index afe8dfc..c7eb760 100644
--- a/tensorflow/compiler/xla/client/value_inference.h
+++ b/tensorflow/compiler/xla/client/value_inference.h
@@ -28,30 +28,49 @@
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 
 namespace xla {
-// OptionaLiteralSlice is an augmented literal class which returns optional
-// values for each index (the value can be either valid or invalid). Underneath
-// it keeps two literals, a value literal, holding both the valid and garabage
-// value, and a masking litearl representing if a value is valid or garbage.
-class OptionaLiteralSlice {
+// OptionalLiteral is an augmented literal class which returns optional
+// values for each index (the value can be either valid or invalid). The
+// implementation keeps two literals, a value literal, holding both the valid
+// and garbage value, and a masking literal representing if a value is valid or
+// garbage.
+class OptionalLiteral {
  public:
-  explicit OptionaLiteralSlice(LiteralSlice value, LiteralSlice mask)
-      : value_(value), mask_(mask) {}
+  explicit OptionalLiteral(Literal value, Literal mask)
+      : value_(std::move(value)), mask_(std::move(mask)) {}
 
   template <typename NativeT>
-  absl::optional<NativeT> Get(absl::Span<const int64> multi_index) const {
-    if (mask_.Get<bool>(multi_index)) {
+  absl::optional<NativeT> Get(absl::Span<const int64> element_index,
+                              ShapeIndex shape_index = {}) const {
+    if (mask_.Get<bool>(element_index, shape_index)) {
       return absl::nullopt;
     } else {
-      return value_.Get<NativeT>(multi_index);
+      return value_.Get<NativeT>(element_index, shape_index);
     }
   }
 
   // Returns true if all values in this literal slice are value.
   bool AllValid() { return mask_.IsAll(0); }
 
+  // Get value out of this slice if all values are valid. Otherwise returns
+  // nullopt.
+  absl::optional<LiteralSlice> GetValue() {
+    if (!AllValid()) {
+      return absl::nullopt;
+    }
+    return LiteralSlice(value_);
+  }
+
  private:
-  LiteralSlice value_;
-  LiteralSlice mask_;
+  Literal value_;
+  Literal mask_;
+};
+
+enum ValueInferenceMode {
+  // Infer the constant value itself.
+  kValue = 0,
+  // Infer the upper-bound and lower-bound of the value. Bounds are inclusive.
+  kUpperBound,
+  kLowerBound,
 };
 
 class ValueInference {
@@ -61,38 +80,16 @@
   // - What's the lower-bound of each value in a tensor.
   // - What's the constant value of each tensor.
   // - Whether or not each value in a tensor is dynamic.
-  explicit ValueInference(XlaBuilder* builder) : builder_(builder) {}
-  StatusOr<LiteralSlice> AnalyzeUpperBound(XlaOp op) {
-    return Unimplemented("Analyzing upper-bound is not implemented yet.");
+  explicit ValueInference(XlaBuilder* builder) : builder_(builder) {
+    CHECK(builder_);
   }
-  StatusOr<LiteralSlice> AnalyzeLowerBound(XlaOp op) {
-    return Unimplemented("Analyzing lower-bound is not implemented yet.");
-  }
-  StatusOr<LiteralSlice> AnalyzeIsDynamic(XlaOp op) {
-    return AnalyzeIsDynamic(op.handle());
-  }
-
-  // Returns a OptionalConstant, the value is nullopt it's dynamic, otherwise a
-  // concrete constant value.
-  StatusOr<OptionaLiteralSlice> AnalyzeOptionalConstant(XlaOp op) {
-    return AnalyzeOptionalConstant(op.handle());
-  }
+  StatusOr<Literal> AnalyzeIsDynamic(XlaOp op);
+  // Returns an OptionalLiteral. Each individual value of the literal is
+  // the concrete constant value if it can be inferred, otherwise a nullopt.
+  StatusOr<OptionalLiteral> AnalyzeConstant(XlaOp op, ValueInferenceMode mode);
 
  private:
-  StatusOr<LiteralSlice> AnalyzeIsDynamic(int64 handle);
-  StatusOr<LiteralSlice> AnalyzeConstant(int64 handle);
-  StatusOr<OptionaLiteralSlice> AnalyzeOptionalConstant(int64 handle);
-
-  StatusOr<Literal> AnalyzeIsDynamicLiteral(int64 handle);
-  StatusOr<Literal> AnalyzeConstantLiteral(int64 handle);
-
   XlaBuilder* builder_;
-  // Cache to avoid re-evaluating. Mapping of xla handle to evaluated
-  // literals.
-  absl::flat_hash_map<int64, Literal> upper_bound_;
-  absl::flat_hash_map<int64, Literal> lower_bound_;
-  absl::flat_hash_map<int64, Literal> is_dynamic_;
-  absl::flat_hash_map<int64, Literal> constant_;
   HloEvaluator evaluator_;
 };
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 43e5d41..6b95fdf 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -303,7 +303,7 @@
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul",
         "//tensorflow/core:lib",
-        "//third_party/eigen3",
+        "//tensorflow/stream_executor/lib",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/base",
         "@com_google_absl//absl/container:inlined_vector",
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 9816a2d..aa3150e 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -52,6 +52,7 @@
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/types.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
 
 namespace xla {
 
@@ -367,6 +368,40 @@
   return result;
 }
 
+StatusOr<Literal> HloEvaluator::EvaluateElementwiseTernaryOp(
+    HloOpcode opcode, const Literal& lhs, const Literal& rhs,
+    const Literal& ehs) {
+  std::unique_ptr<HloInstruction> lhs_instr =
+      HloInstruction::CreateConstant(lhs.Clone());
+  std::unique_ptr<HloInstruction> rhs_instr =
+      HloInstruction::CreateConstant(rhs.Clone());
+  std::unique_ptr<HloInstruction> ehs_instr =
+      HloInstruction::CreateConstant(ehs.Clone());
+  TF_ASSIGN_OR_RETURN(auto output_shape,
+                      ShapeInference::InferTernaryOpShape(
+                          opcode, lhs.shape(), rhs.shape(), ehs.shape()));
+  std::unique_ptr<HloInstruction> cloned_instruction =
+      HloInstruction::CreateTernary(output_shape, opcode, lhs_instr.get(),
+                                    rhs_instr.get(), ehs_instr.get());
+  return Evaluate(cloned_instruction.get());
+}
+
+StatusOr<Literal> HloEvaluator::EvaluateElementwiseCompareOp(
+    ComparisonDirection direction, const Literal& lhs, const Literal& rhs) {
+  std::unique_ptr<HloInstruction> lhs_instr =
+      HloInstruction::CreateConstant(lhs.Clone());
+  std::unique_ptr<HloInstruction> rhs_instr =
+      HloInstruction::CreateConstant(rhs.Clone());
+
+  std::unique_ptr<HloInstruction> cloned_instruction =
+      HloInstruction::CreateCompare(
+          ShapeUtil::ChangeElementType(lhs.shape(), PRED), lhs_instr.get(),
+          rhs_instr.get(), direction);
+  auto result = Evaluate(cloned_instruction.get());
+
+  return result;
+}
+
 StatusOr<Literal> HloEvaluator::EvaluateElementwiseUnaryOp(
     HloOpcode opcode, const Literal& operand) {
   std::unique_ptr<HloInstruction> operand_instr =
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index 0d64582..a891312 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -124,6 +124,15 @@
   StatusOr<Literal> EvaluateElementwiseUnaryOp(HloOpcode opcode,
                                                const Literal& operand);
 
+  StatusOr<Literal> EvaluateElementwiseTernaryOp(HloOpcode opcode,
+                                                 const Literal& lhs,
+                                                 const Literal& rhs,
+                                                 const Literal& ehs);
+
+  StatusOr<Literal> EvaluateElementwiseCompareOp(ComparisonDirection direction,
+                                                 const Literal& lhs,
+                                                 const Literal& rhs);
+
   StatusOr<Literal> EvaluateDotOp(const DotDimensionNumbers& dim_numbers,
                                   const PrecisionConfig& precision_config,
                                   const Literal& lhs, const Literal& rhs);
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 822dec1..806d5bd 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -2141,6 +2141,7 @@
         "//tensorflow/compiler/xla/client/lib:prng",
         "//tensorflow/core:lib",
         "//tensorflow/core:test",
+        "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
     ],
diff --git a/tensorflow/compiler/xla/tests/dynamism_inference_test.cc b/tensorflow/compiler/xla/tests/dynamism_inference_test.cc
index 972b237..a5019fa 100644
--- a/tensorflow/compiler/xla/tests/dynamism_inference_test.cc
+++ b/tensorflow/compiler/xla/tests/dynamism_inference_test.cc
@@ -43,35 +43,17 @@
 namespace xla {
 namespace {
 
-// An enumerator for the client types that we want to iterate over in
-// the various tests.
-enum class ClientType { kLocal, kCompileOnly };
-
-class DynamismInferenceTest : public ::testing::Test {
+class ValueInferenceTest : public ::testing::Test {
  public:
-  explicit DynamismInferenceTest(se::Platform* platform = nullptr)
-      : platform_(platform) {}
-
   string TestName() const {
     return ::testing::UnitTest::GetInstance()->current_test_info()->name();
   }
+};
 
-  Client* ClientOrDie(se::Platform* platform, ClientType client_type) {
-    if (client_type == ClientType::kLocal) {
-      StatusOr<Client*> result =
-          ClientLibrary::GetOrCreateLocalClient(platform);
-      TF_CHECK_OK(result.status())
-          << "could not create LocalClient for testing";
-      return result.ValueOrDie();
-    } else if (client_type == ClientType::kCompileOnly) {
-      StatusOr<Client*> result =
-          ClientLibrary::GetOrCreateCompileOnlyClient(platform);
-      TF_CHECK_OK(result.status())
-          << "could not create CompileOnlyClient for testing";
-      return result.ValueOrDie();
-    }
-    LOG(FATAL) << "invalid client_type value";
-  }
+class DynamismInferenceTest : public ValueInferenceTest {
+ public:
+  explicit DynamismInferenceTest(se::Platform* platform = nullptr)
+      : platform_(platform) {}
 
   StatusOr<Literal> ComputeDynamismLiteral(XlaOp operand, XlaBuilder* builder,
                                            Layout* output_layout = nullptr) {
@@ -316,5 +298,80 @@
   EXPECT_TRUE(ComputeDynamismLiteral(pad, &b).ValueOrDie().Get<bool>({2}));
 }
 
+class UpperBoundInferenceTest : public ValueInferenceTest {
+ public:
+  explicit UpperBoundInferenceTest(se::Platform* platform = nullptr)
+      : platform_(platform) {}
+
+  StatusOr<OptionalLiteral> ComputeUpperBoundLiteral(
+      XlaOp operand, XlaBuilder* builder, Layout* output_layout = nullptr) {
+    ValueInference value_inference(builder);
+    TF_ASSIGN_OR_RETURN(auto literal,
+                        value_inference.AnalyzeConstant(
+                            operand, ValueInferenceMode::kUpperBound));
+    return literal;
+  }
+
+  se::Platform* platform_;
+};
+
+TEST_F(UpperBoundInferenceTest, GetDimensionSize) {
+  XlaBuilder b(TestName());
+  auto p =
+      Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}), "p0");
+
+  auto gds0 = GetDimensionSize(p, 0);
+  auto gds1 = GetDimensionSize(p, 1);
+  auto tuple_2 = Tuple(&b, {gds0, gds1});
+  EXPECT_EQ(
+      ComputeUpperBoundLiteral(tuple_2, &b).ValueOrDie().Get<int32>({}, {0}),
+      2);
+  EXPECT_EQ(
+      ComputeUpperBoundLiteral(tuple_2, &b).ValueOrDie().Get<int32>({}, {1}),
+      3);
+}
+
+TEST_F(UpperBoundInferenceTest, GetDimensionSizeSub) {
+  XlaBuilder b(TestName());
+  auto p =
+      Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}), "p0");
+
+  // The range of the first dimension is [0, 2]
+  auto gds0 = GetDimensionSize(p, 0);
+  // The range of the second dimension is [3, 3]
+  auto gds1 = GetDimensionSize(p, 1);
+  // Upper bound of `second_dimension - first_dimension` is 3 - 0 = 3
+  auto sub = Sub(gds1, gds0);
+  EXPECT_EQ(ComputeUpperBoundLiteral(sub, &b).ValueOrDie().Get<int32>({}), 3);
+}
+
+TEST_F(UpperBoundInferenceTest, GetDimensionSizeDiv) {
+  XlaBuilder b(TestName());
+  auto p =
+      Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}), "p0");
+  // The range of the first dimension is [0, 2]
+  auto gds0 = GetDimensionSize(p, 0);
+  // The range of the second dimension is [3, 3]
+  auto gds1 = GetDimensionSize(p, 1);
+  // Upper bound of `second_dimension / first_dimension` is 3 / 1 = 3. Notice we
+  // don't use 0 as the lower bound as it would cause a divide-by-zero error.
+  auto sub = Div(gds1, gds0);
+  EXPECT_EQ(ComputeUpperBoundLiteral(sub, &b).ValueOrDie().Get<int32>({}), 3);
+}
+
+TEST_F(UpperBoundInferenceTest, ParamCantInferBound) {
+  // We can infer a parameter's dimension's bound, but not the parameter value's
+  // bound.
+  XlaBuilder b(TestName());
+  auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2}, {true}), "p0");
+  auto p1 = Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}, {}), "p1");
+  auto gds = GetDimensionSize(p0, 0);
+  auto sub = Div(gds, p1);
+  EXPECT_FALSE(ComputeUpperBoundLiteral(sub, &b)
+                   .ValueOrDie()
+                   .Get<int32>({})
+                   .has_value());
+}
+
 }  // namespace
 }  // namespace xla