Start using value inference when resolving constant values.

- Change the rest of the kernels that support dynamic shapes to use value inference.
- Update pad_op so that it can infer a tighter bound using value inference.
- Update slice_op so it doesn't create a shape with all dynamic dimensions when some slice sizes are constant.

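Most of the kernel changes below share one pattern: resolve a shape operand to
its upper bound, build the op at that bounded shape, then re-attach the sizes
of the dimensions that are actually dynamic. A condensed sketch of the pattern
(not any single kernel; operand index 0 and BuildAtUpperBound are
placeholders):

  std::vector<int64_t> dims;
  OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(
                          0, &dims, xla::ValueInferenceMode::kUpperBound));
  std::vector<bool> dynamic_dims;
  OP_REQUIRES_OK(ctx,
                 ctx->ResolveInputDynamismIntoPredVector(0, &dynamic_dims));
  xla::XlaOp result = BuildAtUpperBound(dims);  // kernel-specific op build
  for (int64_t i = 0; i < dims.size(); ++i) {
    // If a dimension is dynamic, read its runtime size out of the shape
    // operand and call set-dimension-size on the output.
    if (dynamic_dims[i]) {
      auto size = xla::Slice(ctx->Input(0), {i}, {i + 1}, {1});
      size = xla::ConvertElementType(xla::Reshape(size, {}), xla::S32);
      result = xla::SetDimensionSize(result, size, i);
    }
  }
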
PiperOrigin-RevId: 391081565
Change-Id: Iccfbdd8b899bea8dbc42980c27b6c995233b3ccd
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index f62323f..08c8e21 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -182,6 +182,7 @@
         "//tensorflow/compiler/xla/client/lib:arithmetic",
         "//tensorflow/compiler/xla/client/lib:comparators",
         "//tensorflow/compiler/xla/client/lib:constants",
+        "//tensorflow/compiler/xla/client/lib:dynamic_shaped_ops",
         "//tensorflow/compiler/xla/client/lib:loops",
         "//tensorflow/compiler/xla/client/lib:math",
         "//tensorflow/compiler/xla/client/lib:matrix",
diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc
index 7ada23a..4c6cb3b 100644
--- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc
@@ -20,6 +20,7 @@
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/value_inference.h"
 #include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/types.h"
@@ -46,7 +47,8 @@
                   errors::InvalidArgument("In[", i, "] must be a vector.",
                                           in_shape.DebugString()));
       std::vector<int64_t> shape;
-      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(i, &shape));
+      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(
+                              i, &shape, xla::ValueInferenceMode::kUpperBound));
       shapes.push_back(BCast::Vec(shape.begin(), shape.end()));
     }
     BCast bcast(shapes[0], shapes[1]);
@@ -95,7 +97,11 @@
                   errors::InvalidArgument("In[", i, "] must be a vector.",
                                           in_shape.DebugString()));
       std::vector<int64_t> vec;
-      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(i, &vec));
+      // Technically we don't need to infer the upper bound here. However, the
+      // forward path uses the upper bound as the bounded shape, so the
+      // backward path has to use the same shape to decide the reduction
+      // indices.
+      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(
+                              i, &vec, xla::ValueInferenceMode::kUpperBound));
 
       shapes.push_back(BCast::Vec(vec.begin(), vec.end()));
     }
diff --git a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
index f872361..8847692 100644
--- a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
@@ -30,7 +30,9 @@
 
   void Compile(XlaOpKernelContext* context) override {
     TensorShape output_shape;
-    OP_REQUIRES_OK(context, context->ConstantInputAsShape(1, &output_shape));
+    OP_REQUIRES_OK(context,
+                   context->ConstantInputAsShape(
+                       1, &output_shape, xla::ValueInferenceMode::kUpperBound));
     auto output_status_or =
         BroadcastTo(context->Input(0), output_shape.dim_sizes());
     OP_REQUIRES_OK(context, output_status_or.status());
diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
index 51e6b89..e2e0fc8 100644
--- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
@@ -45,7 +45,9 @@
     const xla::XlaOp& logits = ctx->Input(0);
     TensorShape logits_shape = ctx->InputShape(0);
     int64_t num_samples;
-    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_samples));
+    OP_REQUIRES_OK(ctx,
+                   ctx->ConstantInputAsIntScalar(
+                       1, &num_samples, xla::ValueInferenceMode::kUpperBound));
     OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape),
                 errors::InvalidArgument("logits should be a matrix, got shape ",
                                         logits_shape.DebugString()));
@@ -66,7 +68,10 @@
 
     xla::Shape uniform_shape;
     int class_dimension;
-    if (num_samples != 1) {
+    bool num_samples_is_dynamic = false;
+    OP_REQUIRES_OK(
+        ctx, ctx->ResolveInputDynamismIntoPred(1, &num_samples_is_dynamic));
+    if (num_samples != 1 || num_samples_is_dynamic) {
       std::array<int64_t, 3> uniform_shape_array = {
           {batch_size, num_samples, num_classes}};
       xla::PrimitiveType uniform_xla_type;
@@ -91,11 +96,9 @@
     xla::PrimitiveType type;
     OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(0), &type));
     xla::XlaOp log_uniforms = GetLogUniforms(uniform_shape, type, ctx);
-    bool num_samples_is_dynamic = false;
-    OP_REQUIRES_OK(
-        ctx, ctx->ResolveInputDynamismIntoPred(1, &num_samples_is_dynamic));
-    if (num_samples_is_dynamic && num_samples != 1) {
-      // Number samples is dimension 1 in uniform_shape_array.
+
+    if (num_samples_is_dynamic) {
+      // num_samples is dimension 1 in uniform_shape_array.
       log_uniforms = xla::SetDimensionSize(log_uniforms, ctx->Input(1), 1);
     }
 
@@ -119,7 +122,7 @@
                            /*axis=*/class_dimension, /*stable=*/true);
     }
 
-    if (num_samples == 1) {
+    if (num_samples == 1 && !num_samples_is_dynamic) {
       argmax = xla::Reshape(argmax, {batch_size, 1});
     }
 
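Note the reordering above: the dynamism of num_samples is now resolved before
choosing uniform_shape, because a dynamic num_samples whose upper bound is 1
must still take the rank-3 path so that SetDimensionSize can later be applied
to the samples dimension.
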
diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc
index fbf403b..1cb67d8 100644
--- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc
@@ -155,7 +155,9 @@
     const int32_t N = ctx->num_inputs() - 1;
     const TensorShape inp0_shape = ctx->InputShape(1);
     std::vector<int64_t> inp0_dims;
-    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &inp0_dims));
+    OP_REQUIRES_OK(ctx,
+                   ctx->ConstantInputAsIntVector(
+                       1, &inp0_dims, xla::ValueInferenceMode::kUpperBound));
     const int64_t inp0_rank = inp0_shape.num_elements();
 
     int64_t cdim;
@@ -174,7 +176,9 @@
                                           inp0_rank, " elements, but got ",
                                           inp_shape.num_elements()));
       std::vector<int64_t> inp_dims;
-      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1 + i, &inp_dims));
+      OP_REQUIRES_OK(
+          ctx, ctx->ConstantInputAsIntVector(
+                   1 + i, &inp_dims, xla::ValueInferenceMode::kUpperBound));
 
       Tensor out_constant(DT_INT32, TensorShape({inp0_rank}));
       auto out_vec = out_constant.vec<int32>();
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
index 53bbb86..6ebb6b5 100644
--- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
@@ -104,7 +104,9 @@
 
   void Compile(XlaOpKernelContext* ctx) override {
     TensorShape input_tensor_shape;
-    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_tensor_shape));
+    OP_REQUIRES_OK(
+        ctx, ctx->ConstantInputAsShape(0, &input_tensor_shape,
+                                       xla::ValueInferenceMode::kUpperBound));
     xla::Shape input_shape =
         TensorShapeToXLAShape(ctx->input_xla_type(1), input_tensor_shape);
     OP_REQUIRES(ctx, input_shape.rank() == attrs_.num_spatial_dims + 2,
@@ -170,7 +172,9 @@
 
   void Compile(XlaOpKernelContext* ctx) override {
     TensorShape filter_tensor_shape;
-    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &filter_tensor_shape));
+    OP_REQUIRES_OK(
+        ctx, ctx->ConstantInputAsShape(1, &filter_tensor_shape,
+                                       xla::ValueInferenceMode::kUpperBound));
     xla::Shape filter_shape =
         TensorShapeToXLAShape(ctx->input_xla_type(0), filter_tensor_shape);
 
diff --git a/tensorflow/compiler/tf2xla/kernels/fill_op.cc b/tensorflow/compiler/tf2xla/kernels/fill_op.cc
index 1c9d81d2..ebcbadb 100644
--- a/tensorflow/compiler/tf2xla/kernels/fill_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/fill_op.cc
@@ -19,6 +19,7 @@
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/value_inference.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/core/framework/kernel_def_builder.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -45,17 +46,17 @@
                                         value_shape.DebugString()));
 
     std::vector<int64_t> dims;
-    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector("dims", &dims));
-    // Set dynamic dimension value to -1 so that we know which dimension is
-    // dynamic.
-    ctx->set_dynamic_dimension_is_minus_one(true);
-    std::vector<int64> dynamic_dims;
-    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector("dims", &dynamic_dims));
+    OP_REQUIRES_OK(ctx,
+                   ctx->ConstantInputAsIntVector(
+                       "dims", &dims, xla::ValueInferenceMode::kUpperBound));
+    std::vector<bool> dynamic_dims;
+    OP_REQUIRES_OK(
+        ctx, ctx->ResolveInputDynamismIntoPredVector("dims", &dynamic_dims));
 
     auto output = xla::Broadcast(ctx->Input("value"), dims);
     for (int64_t i = 0; i < dims.size(); ++i) {
       // If a dimension is dynamic, call set-dimension-size on the output.
-      if (dynamic_dims[i] == -1) {
+      if (dynamic_dims[i]) {
         auto dynamic_dim_size = xla::Slice(ctx->Input(0), {i}, {i + 1}, {1});
         dynamic_dim_size = xla::Reshape(dynamic_dim_size, {});
         dynamic_dim_size = xla::ConvertElementType(dynamic_dim_size, xla::S32);
diff --git a/tensorflow/compiler/tf2xla/kernels/if_while_utils.cc b/tensorflow/compiler/tf2xla/kernels/if_while_utils.cc
index c00cdf5..dcd2c07 100644
--- a/tensorflow/compiler/tf2xla/kernels/if_while_utils.cc
+++ b/tensorflow/compiler/tf2xla/kernels/if_while_utils.cc
@@ -38,16 +38,21 @@
       VLOG(1) << "Trying to resolve constant " << i;
       // NOTE: We can not simply check that this is Kind::kConstant because
       // this could be the output of a MetadataOnly op e.g. Size.
+
+      // If we can infer the constant value of an inner computation's
+      // argument, replace it with a constant. If that fails, we fall back to
+      // inferring the argument's bounds.
       StatusOr<absl::optional<Tensor>> maybe_constant =
           expression.ResolveConstant(ctx->compiler()->client());
-      if (maybe_constant.ok() && maybe_constant.ValueOrDie().has_value()) {
+      StatusOr<absl::optional<Tensor>> bounds =
+          expression.ResolveConstant(ctx->compiler()->client(), false,
+                                     xla::ValueInferenceMode::kUpperBound);
+      if ((maybe_constant.ok() && maybe_constant->has_value()) ||
+          (bounds.ok() && bounds->has_value())) {
         StatusOr<Tensor> values_are_dynamic =
             expression.ResolveDynamism(ctx->compiler()->client());
         bool all_values_are_static = false;
-        if (!values_are_dynamic.ok()) {
-          // Conservatiely assume all values are dynamic.
-          all_values_are_static = true;
-        } else {
+        if (values_are_dynamic.ok()) {
           xla::Literal literal =
               HostTensorToLiteral(values_are_dynamic.ValueOrDie()).ValueOrDie();
           all_values_are_static = literal.IsAll(0);
@@ -60,8 +65,7 @@
           arg->shape = expression.GetShape().ValueOrDie();
           resolved_constant_idxs.push_back(i);
         } else {
-          arg->value_bound.emplace(
-              std::move(maybe_constant.ValueOrDie().value()));
+          arg->value_bound.emplace(std::move(bounds.ValueOrDie().value()));
           arg->value_dynamism.emplace(
               std::move(values_are_dynamic.ValueOrDie()));
         }
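
In effect, arguments of If/While computations are now resolved in two tiers:
when ResolveConstant succeeds they are replaced with exact constants, and when
only the bound query succeeds the bound is recorded in arg->value_bound
(together with per-element dynamism in arg->value_dynamism) so the inner
computation can still be compiled against a bounded shape.
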
diff --git a/tensorflow/compiler/tf2xla/kernels/pad_op.cc b/tensorflow/compiler/tf2xla/kernels/pad_op.cc
index 792a3cb..09119a7 100644
--- a/tensorflow/compiler/tf2xla/kernels/pad_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/pad_op.cc
@@ -17,7 +17,9 @@
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/value_inference.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/framework/kernel_def_builder.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor_shape.h"
@@ -52,8 +54,13 @@
     }
 
     xla::Literal pad_literal;
-    OP_REQUIRES_OK(ctx,
-                   ctx->ConstantInputAsInt64Literal("paddings", &pad_literal));
+    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(
+                            "paddings", &pad_literal,
+                            xla::ValueInferenceMode::kUpperBound));
+
+    xla::Literal padding_dynamism_literal;
+    OP_REQUIRES_OK(
+        ctx, ctx->ResolveInputDynamism("paddings", &padding_dynamism_literal));
 
     xla::PaddingConfig config;
     for (int i = 0; i < dims; ++i) {
@@ -69,15 +76,60 @@
 
     // PadV2 added a "constant_values" input that indicates the pad value.
     xla::XlaOp constant_values;
+    xla::XlaOp pad;
     if (ctx->num_inputs() == 3) {
       OP_REQUIRES(
           ctx, TensorShapeUtils::IsScalar(ctx->InputShape("constant_values")),
           errors::InvalidArgument("constant_values must be a scalar."));
-      ctx->SetOutput(0, xla::Pad(input, ctx->Input("constant_values"), config));
+      pad = xla::Pad(input, ctx->Input("constant_values"), config);
     } else {
       auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0));
-      ctx->SetOutput(0, xla::Pad(input, zero, config));
+      pad = xla::Pad(input, zero, config);
     }
+
+    for (int i = 0; i < dims; ++i) {
+      bool low_pad_is_dynamic = padding_dynamism_literal.Get<bool>({i, 0});
+
+      OP_REQUIRES(
+          ctx, !low_pad_is_dynamic,
+          errors::InvalidArgument("low_pad in Pad op has to be static."));
+      bool high_pad_is_dynamic = padding_dynamism_literal.Get<bool>({i, 1});
+      if (high_pad_is_dynamic) {
+        // When we have
+        //   pad_width = MAX_WIDTH - size(t)
+        //   op = pad(t, /*high_pad=*/pad_width)
+        // the bound of the result size should be MAX_WIDTH, instead of
+        // `bound(t) + bound(pad_width)`.
+        //
+        // We do this by building the expression
+        //   size(op) = size(t) + MAX_WIDTH - size(t)
+        // and letting value inference analyze it as a whole.
+        xla::XlaOp high_pad_size =
+            xla::Slice(ctx->Input("paddings"), {i, 1}, {i + 1, 2}, {1, 1});
+        high_pad_size = xla::Reshape(high_pad_size, {});
+        high_pad_size = xla::ConvertElementType(high_pad_size, xla::S32);
+        // Low pad has to be static.
+        xla::XlaOp low_pad_size = xla::ConstantR0<int32>(
+            ctx->builder(), pad_literal.Get<int64>({i, 0}));
+        xla::XlaOp input_size = xla::GetDimensionSize(input, i);
+        xla::XlaOp total_size = low_pad_size + input_size + high_pad_size;
+        auto size_upper_bound_status_or =
+            ctx->value_inference().AnalyzeConstant(
+                total_size, xla::ValueInferenceMode::kUpperBound);
+        OP_REQUIRES_OK(ctx, size_upper_bound_status_or.status());
+        auto size_upper_bound =
+            size_upper_bound_status_or.ValueOrDie().Get<int32>({});
+        OP_REQUIRES(
+            ctx, size_upper_bound.has_value(),
+            errors::InvalidArgument(
+                "Failed to infer upperbound of total size after padding."));
+        // If we know a tighter upperbound, trim the output with the new
+        // upperbound.
+        pad = xla::SliceInDim(pad, 0, size_upper_bound.value(), 1, i);
+        pad = xla::SetDimensionSize(pad, total_size, i);
+      }
+    }
+    ctx->SetOutput(0, pad);
   }
 };
 
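A worked instance of the comment above, with assumed numbers: take
bound(t) = 8, MAX_WIDTH = 10 and low_pad = 0. Bounding each term separately
would give bound(t) + bound(pad_width) = 8 + 10 = 18, whereas handing the
single expression size(t) + (10 - size(t)) to value inference is what allows
the tight bound 10, at which point the output is trimmed with SliceInDim and
re-marked dynamic with SetDimensionSize.
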
diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
index 3997e1d..546b205 100644
--- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
@@ -30,6 +30,7 @@
 #include "tensorflow/compiler/xla/client/lib/comparators.h"
 #include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "tensorflow/compiler/xla/client/lib/loops.h"
+#include "tensorflow/compiler/xla/client/value_inference.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -44,7 +45,8 @@
 
   void Compile(XlaOpKernelContext* ctx) override {
     TensorShape shape;
-    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
+    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(
+                            0, &shape, xla::ValueInferenceMode::kUpperBound));
 
     const DataType dtype = output_type(0);
     xla::Shape xla_shape;
@@ -58,7 +60,18 @@
         << name();
     xla::XlaOp result = xla::RngUniform(XlaHelpers::Zero(b, dtype),
                                         XlaHelpers::One(b, dtype), xla_shape);
-
+    std::vector<bool> dynamic_dims;
+    OP_REQUIRES_OK(ctx,
+                   ctx->ResolveInputDynamismIntoPredVector(0, &dynamic_dims));
+    for (int64_t i = 0; i < xla_shape.rank(); ++i) {
+      // If a dimension is dynamic, call set-dimension-size on the output.
+      if (dynamic_dims[i]) {
+        auto dynamic_dim_size = xla::Slice(ctx->Input(0), {i}, {i + 1}, {1});
+        dynamic_dim_size = xla::Reshape(dynamic_dim_size, {});
+        dynamic_dim_size = xla::ConvertElementType(dynamic_dim_size, xla::S32);
+        result = xla::SetDimensionSize(result, dynamic_dim_size, i);
+      }
+    }
     ctx->SetOutput(0, result);
   }
 
diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc
index fca4938..41f2c2e 100644
--- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc
@@ -19,6 +19,7 @@
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/value_inference.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 
 namespace tensorflow {
@@ -57,8 +58,9 @@
     TensorShape indices_shape = ctx->InputShape(1);
 
     int64_t num_segments;
-    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(2, &num_segments));
-
+    OP_REQUIRES_OK(ctx,
+                   ctx->ConstantInputAsIntScalar(
+                       2, &num_segments, xla::ValueInferenceMode::kUpperBound));
     OP_REQUIRES(ctx, data_shape.dims() >= indices_shape.dims(),
                 errors::InvalidArgument(type_string(),
                                         " requires that indices' rank be"
diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc
index 3d7a94e..8dc9b61 100644
--- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc
@@ -21,6 +21,8 @@
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/dynamic_shaped_ops.h"
+#include "tensorflow/compiler/xla/client/value_inference.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/core/framework/op_kernel.h"
@@ -57,10 +59,11 @@
 
     std::vector<int64_t> begin;
     std::vector<int64_t> size;
-    const bool begin_is_constant =
+    const bool all_begins_are_constant =
         ctx->ConstantInputAsIntVector(1, &begin).ok();
-    const bool size_is_constant = ctx->ConstantInputAsIntVector(2, &size).ok();
-    if (begin_is_constant && size_is_constant) {
+    const bool all_sizes_are_constant =
+        ctx->ConstantInputAsIntVector(2, &size).ok();
+    if (all_begins_are_constant && all_sizes_are_constant) {
       std::vector<int64_t> wrapped_size(size.size());
       // `begin` is a compile-time constant.
       for (int i = 0; i < input_dims; ++i) {
@@ -101,12 +104,12 @@
       std::vector<int64_t> strides(begin.size(), 1);
       auto slice = xla::Slice(ctx->Input(0), begin, limits, strides);
       // Check for slice on dynamic dimensions.
-      ctx->set_dynamic_dimension_is_minus_one(true);
-      std::vector<int64> dynamic_size;
-      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &dynamic_size));
+      std::vector<bool> size_is_dynamic;
+      OP_REQUIRES_OK(
+          ctx, ctx->ResolveInputDynamismIntoPredVector(2, &size_is_dynamic));
 
       for (int64_t i = 0; i < size.size(); ++i) {
-        if (dynamic_size[i] == -1) {
+        if (size_is_dynamic[i]) {
           if (size[i] != -1) {
             // If there is a dynamic dimension, properly set dimension size of
             // the slice.
@@ -124,7 +127,7 @@
       // This essentially makes size as dynamic.
       bool constant_size_is_minus_one = false;
       // `begin` or `size` is not a compile-time constant.
-      if (size_is_constant) {
+      if (all_sizes_are_constant) {
         for (int i = 0; i < input_dims; ++i) {
           if (size[i] < 0) {
             OP_REQUIRES(ctx, size[i] == -1,
@@ -147,9 +150,9 @@
         begin_indices.push_back(
             xla::Reshape(xla::Slice(begin, {i}, {i + 1}, {1}), {}));
       }
-      if (size_is_constant && !constant_size_is_minus_one) {
-        ctx->SetOutput(0,
-                       xla::DynamicSlice(ctx->Input(0), begin_indices, size));
+      if (all_sizes_are_constant && !constant_size_is_minus_one) {
+        xla::XlaOp input = ctx->Input(0);
+        ctx->SetOutput(0, xla::DynamicSlice(input, begin_indices, size));
       } else {
         // Size is not constant, use input size as upperbound and then set
         // dimension size on it.
@@ -157,16 +160,17 @@
         // First pad input with input size to avoid OOB -- dynamic slice with
         // OOB slice produces undesired results.
         xla::PaddingConfig padding_config;
+        xla::XlaOp input = ctx->Input(0);
         for (int64_t i = 0; i < input_dims; ++i) {
           auto* dims = padding_config.add_dimensions();
           dims->set_edge_padding_low(0);
           dims->set_edge_padding_high(input_shape.dim_size(i));
           dims->set_interior_padding(0);
+          input = xla::RemoveDynamicDimension(input, i);
         }
-        auto padded_input = xla::Pad(
-            ctx->Input(0), xla::Zero(ctx->builder(), ctx->input_xla_type(0)),
-            padding_config);
-
+        auto padded_input =
+            xla::Pad(input, xla::Zero(ctx->builder(), ctx->input_xla_type(0)),
+                     padding_config);
         // Slice full size out of the input starting from the offsets.
         auto sliced = xla::DynamicSlice(padded_input, begin_indices,
                                         input_shape.dim_sizes());
@@ -179,7 +183,23 @@
                                                   input_shape.dim_size(i)) -
                            begin_indices[i];
           }
-          sliced = xla::SetDimensionSize(sliced, dynamic_size, i);
+          auto constant_size = ctx->value_inference().AnalyzeConstant(
+              dynamic_size, xla::ValueInferenceMode::kValue);
+          OP_REQUIRES_OK(ctx, constant_size.status());
+          if (constant_size->AllValid()) {
+            // Slice size on this dimension is constant. This branch is
+            // triggered when some dimensions' slice sizes are constant while
+            // some are dynamic.
+            sliced = xla::SliceInDim(
+                sliced, 0, constant_size->Get<int32>({}).value(), 1, i);
+          } else {
+            // We gave a generous bound (same as the input) to the output; try
+            // to reset the bound if a tighter one can be found.
+            auto status = xla::SetDimensionSizeWithRebound(
+                &ctx->value_inference(), sliced, dynamic_size, i);
+            OP_REQUIRES_OK(ctx, status.status());
+            sliced = status.ValueOrDie();
+          }
         }
         ctx->SetOutput(0, sliced);
       }
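
For example, with a dynamic begin and size = {4, s} where only s is computed
at run time, dimension 0 takes the AnalyzeConstant branch and is cut to
exactly 4 with SliceInDim, while dimension 1 goes through
SetDimensionSizeWithRebound; only the genuinely dynamic dimension stays
dynamic in the output, which is the slice_op fix called out in the commit
message.
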
diff --git a/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc b/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc
index e6fe82f..bb5dfa5 100644
--- a/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc
@@ -41,7 +41,9 @@
 
     // output_shape
     TensorShape output_shape;
-    OP_REQUIRES_OK(context, context->ConstantInputAsShape(1, &output_shape));
+    OP_REQUIRES_OK(context,
+                   context->ConstantInputAsShape(
+                       1, &output_shape, xla::ValueInferenceMode::kUpperBound));
     OP_REQUIRES(context, output_shape.dims() == num_dims,
                 errors::InvalidArgument(
                     "output_shape has incorrect number of elements: ",
@@ -72,7 +74,20 @@
     }
     xla::XlaBuilder* builder = context->builder();
     auto buffer = Broadcast(default_value, output_shape.dim_sizes());
+    std::vector<bool> dynamic_dims;
+    OP_REQUIRES_OK(
+        context, context->ResolveInputDynamismIntoPredVector(1, &dynamic_dims));
 
+    for (int64_t i = 0; i < dynamic_dims.size(); ++i) {
+      // If a dimension is dynamic, call set-dimension-size on the output.
+      if (dynamic_dims[i]) {
+        auto dynamic_dim_size =
+            xla::Slice(context->Input(1), {i}, {i + 1}, {1});
+        dynamic_dim_size = xla::Reshape(dynamic_dim_size, {});
+        dynamic_dim_size = xla::ConvertElementType(dynamic_dim_size, xla::S32);
+        buffer = xla::SetDimensionSize(buffer, dynamic_dim_size, i);
+      }
+    }
     auto result = XlaScatter(buffer, sparse_values, indices,
                              /*indices_are_vectors=*/indices_shape.dims() > 1,
                              /*combiner=*/{}, builder);
diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
index fe0cc9c..4a4dcb2 100644
--- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
@@ -25,6 +25,8 @@
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/dynamic_shaped_ops.h"
+#include "tensorflow/compiler/xla/client/value_inference.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/util.h"
@@ -134,6 +136,11 @@
       // Pad input to 2x to avoid OOB access.
       slice = xla::Pad(slice, xla::Zero(ctx->builder(), ctx->input_xla_type(0)),
                        padding_config);
+      for (int64 i = 0; i < result_dims_are_dynamic.size(); ++i) {
+        if (result_dims_are_dynamic[i]) {
+          slice = xla::RemoveDynamicDimension(slice, i);
+        }
+      }
     }
     std::vector<xla::XlaOp> start_indices;
     std::vector<xla::XlaOp> slice_sizes_dynamic;
@@ -156,16 +163,30 @@
             xla::ConstantR0WithType(ctx->builder(), ctx->InputXlaType("begin"),
                                     input_xla_shape.dimensions(i));
       }
+
+      auto scalar_must_be_non_negative = [ctx](xla::XlaOp value) -> bool {
+        // Check if the lower bound of a value is always >= 0.
+        auto lower_bound = ctx->value_inference().AnalyzeConstant(
+            value, xla::ValueInferenceMode::kLowerBound);
+        if (!lower_bound.ok() || !lower_bound->AllValid()) {
+          // Can't infer a lower bound.
+          return false;
+        }
+        return lower_bound->Get<int32>({}) >= 0;
+      };
       if (begin_mask) {
         begin_index = zero;
       } else {
         begin_index = xla::Slice(ctx->Input("begin"), {sparse_index},
                                  {sparse_index + 1}, {1});
         begin_index = xla::Reshape(begin_index, {});
-        auto index_negative = xla::Lt(begin_index, zero);
-        auto wrapped_index = xla::Add(dim_size, begin_index);
-        // Wrap negative indices around.
-        begin_index = xla::Select(index_negative, wrapped_index, begin_index);
+        if (!scalar_must_be_non_negative(begin_index)) {
+          // begin could be negative.
+          auto index_negative = xla::Lt(begin_index, zero);
+          auto wrapped_index = xla::Add(dim_size, begin_index);
+          // Wrap negative indices around.
+          begin_index = xla::Select(index_negative, wrapped_index, begin_index);
+        }
       }
       start_indices.push_back(begin_index);
       if (end_mask) {
@@ -174,9 +195,12 @@
         end_index = xla::Slice(ctx->Input("end"), {sparse_index},
                                {sparse_index + 1}, {1});
         end_index = xla::Reshape(end_index, {});
-        auto index_negative = xla::Lt(end_index, zero);
-        auto wrapped_index = xla::Add(dim_size, end_index);
-        end_index = xla::Select(index_negative, wrapped_index, end_index);
+        if (!scalar_must_be_non_negative(end_index)) {
+          // end could be negative.
+          auto index_negative = xla::Lt(end_index, zero);
+          auto wrapped_index = xla::Add(dim_size, end_index);
+          end_index = xla::Select(index_negative, wrapped_index, end_index);
+        }
       }
       slice_sizes_dynamic.push_back(
           xla::Max(xla::Sub(end_index, begin_index), zero));
@@ -184,13 +208,24 @@
 
     slice =
         xla::DynamicSlice(slice, start_indices, processing_shape.dim_sizes());
-
-    for (int64_t i = 0; i < input_shape.dims(); ++i) {
-      if (result_dims_are_dynamic[i]) {
-        slice = xla::SetDimensionSize(slice, slice_sizes_dynamic[i], i);
+    // new_axis_mask_, ellipsis_mask_, and shrink_axis_mask_ may add or remove
+    // size-1 dims of a shape.
+    slice = xla::Reshape(slice, final_shape.dim_sizes());
+    for (int64_t i = 0; i < final_shape.dims(); ++i) {
+      int64 processing_shape_dim = shape_spec.output_to_processing_mapping[i];
+      // If processing_shape_dim is -1, the output dimension was newly added
+      // by new_axis_mask_ and doesn't show up in the input.
+      if (processing_shape_dim != -1 &&
+          result_dims_are_dynamic[processing_shape_dim]) {
+        // We gave a generous bound (same as the input) to the output; try
+        // to reset the bound if a tighter one can be found.
+        auto status = xla::SetDimensionSizeWithRebound(
+            &ctx->value_inference(), slice,
+            slice_sizes_dynamic[processing_shape_dim], i);
+        OP_REQUIRES_OK(ctx, status.status());
+        slice = status.ValueOrDie();
       }
     }
-    slice = xla::Reshape(slice, final_shape.dim_sizes());
     ctx->SetOutput(0, slice);
   }
 
@@ -483,7 +518,9 @@
     absl::InlinedVector<int64_t, 4> strides;
 
     TensorShape input_shape;
-    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape));
+    OP_REQUIRES_OK(
+        ctx, ctx->ConstantInputAsShape(0, &input_shape,
+                                       xla::ValueInferenceMode::kUpperBound));
     xla::Literal begin_literal, end_literal, strides_literal;
 
     bool begin_is_constant = ctx->ConstantInput(1, &begin_literal).ok();
@@ -565,14 +602,14 @@
 
     xla::XlaOp dynamic_shape = ctx->Input(0);
     xla::Shape grad_shape = ctx->builder()->GetShape(grad).ValueOrDie();
-    ctx->set_dynamic_dimension_is_minus_one(true);
-    std::vector<int64> dynamic_size;
-    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(0, &dynamic_size));
+    std::vector<bool> dynamic_input;
+    OP_REQUIRES_OK(ctx,
+                   ctx->ResolveInputDynamismIntoPredVector(0, &dynamic_input));
     // Input of strided_slice_op has to have the same shape as output.
     DCHECK_EQ(grad_shape.rank(), input_shape.dims());
     for (int64_t dim = 0; dim < input_shape.dims(); ++dim) {
       DCHECK_EQ(grad_shape.dimensions(dim), input_shape.dim_size(dim));
-      if (dynamic_size[dim] == -1) {
+      if (dynamic_input[dim]) {
         // Input is a dynamic dimension, set the same dynamic dimension size in
         // the output.
         auto dim_size = xla::Slice(dynamic_shape, {dim}, {dim + 1}, {1});
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc
index e043fbc..ef06e9d 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc
@@ -144,7 +144,9 @@
 
   void Compile(XlaOpKernelContext* ctx) override {
     int64_t num_elements;
-    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_elements));
+    OP_REQUIRES_OK(ctx,
+                   ctx->ConstantInputAsIntScalar(
+                       1, &num_elements, xla::ValueInferenceMode::kUpperBound));
     bool num_element_is_dynamic;
     OP_REQUIRES_OK(
         ctx, ctx->ResolveInputDynamismIntoPred(1, &num_element_is_dynamic));
@@ -214,7 +216,9 @@
 
   void Compile(XlaOpKernelContext* ctx) override {
     int64_t max_num_elements;
-    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &max_num_elements));
+    OP_REQUIRES_OK(
+        ctx, ctx->ConstantInputAsIntScalar(
+                 1, &max_num_elements, xla::ValueInferenceMode::kUpperBound));
     bool num_element_is_dynamic;
     OP_REQUIRES_OK(
         ctx, ctx->ResolveInputDynamismIntoPred(1, &num_element_is_dynamic));
diff --git a/tensorflow/compiler/tf2xla/xla_expression.cc b/tensorflow/compiler/tf2xla/xla_expression.cc
index c3b9a51..8bd93eb 100644
--- a/tensorflow/compiler/tf2xla/xla_expression.cc
+++ b/tensorflow/compiler/tf2xla/xla_expression.cc
@@ -166,8 +166,14 @@
           "ResolveConstant called on XlaExpression: ", HumanString());
   }
   TF_ASSIGN_OR_RETURN(TensorShape shape, GetShape());
+  // The XLA layout is specified minor to major, and TensorFlow uses a major to
+  // minor order.
+  std::vector<int64_t> layout_indices(shape.dims());
+  std::iota(layout_indices.rbegin(), layout_indices.rend(), 0);
+  xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices);
   if (mode == xla::ValueInferenceMode::kLowerBound ||
-      mode == xla::ValueInferenceMode::kUpperBound) {
+      mode == xla::ValueInferenceMode::kUpperBound ||
+      mode == xla::ValueInferenceMode::kValue) {
     std::vector<int64_t> layout_indices(shape.dims());
     std::iota(layout_indices.rbegin(), layout_indices.rend(), 0);
     xla::ValueInference value_inference(handle().builder());
@@ -177,8 +183,8 @@
       return {absl::nullopt};
     }
     Tensor tensor;
-    TF_RETURN_IF_ERROR(
-        LiteralToHostTensor(literal.GetValue().value(), dtype(), &tensor));
+    TF_RETURN_IF_ERROR(LiteralToHostTensor(
+        literal.GetValue().value().Relayout(layout), dtype(), &tensor));
     return {tensor};
   }
 
@@ -195,11 +201,6 @@
                       handle().builder()->BuildConstantSubGraph(
                           handle(), dynamic_dimension_is_minus_one));
 
-  // The XLA layout is specified minor to major, and TensorFlow uses a major to
-  // minor order.
-  std::vector<int64> layout_indices(shape.dims());
-  std::iota(layout_indices.rbegin(), layout_indices.rend(), 0);
-  xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices);
   TF_ASSIGN_OR_RETURN(xla::Literal literal,
                       client->ComputeConstant(constant_graph, &layout));
   Tensor tensor;
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index 1f533ad..4c2c066 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -23,6 +23,7 @@
 #include "tensorflow/compiler/tf2xla/type_util.h"
 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
 #include "tensorflow/compiler/tf2xla/xla_context.h"
+#include "tensorflow/compiler/xla/client/value_inference.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/compiler/xla/status_macros.h"
@@ -32,7 +33,9 @@
 namespace tensorflow {
 
 XlaOpKernelContext::XlaOpKernelContext(OpKernelContext* context)
-    : context_(context), dynamic_dimension_is_minus_one_(false) {}
+    : context_(context),
+      dynamic_dimension_is_minus_one_(false),
+      value_inference_(xla_context()->builder()) {}
 
 bool XlaOpKernelContext::ValidateInputsAreSameShape(OpKernel* op) {
   return context_->ValidateInputsAreSameShape(op);
@@ -46,6 +49,10 @@
   return xla_context()->builder();
 }
 
+xla::ValueInference& XlaOpKernelContext::value_inference() {
+  return value_inference_;
+}
+
 XlaCompiler* XlaOpKernelContext::compiler() const {
   return xla_context()->compiler();
 }
@@ -154,6 +161,18 @@
   return start;
 }
 
+Status XlaOpKernelContext::ResolveInputDynamism(
+    int index, xla::Literal* dynamism_literal) {
+  return ResolveInputDynamismReshaped(
+      index, context_->input(index).shape().dim_sizes(), dynamism_literal);
+}
+
+Status XlaOpKernelContext::ResolveInputDynamism(
+    absl::string_view name, xla::Literal* dynamism_literal) {
+  TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
+  return ResolveInputDynamism(index, dynamism_literal);
+}
+
 Status XlaOpKernelContext::ConstantInput(absl::string_view name,
                                          xla::Literal* constant_literal,
                                          xla::ValueInferenceMode mode) {
@@ -301,32 +320,56 @@
 }
 
 Status XlaOpKernelContext::ResolveInputDynamismIntoPredVector(
-    int index, std::vector<bool>* out) {
-  xla::Literal literal;
+    absl::string_view name, std::vector<bool>* out) {
+  TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
+  return ResolveInputDynamismIntoPredVector(index, out);
+}
+
+Status XlaOpKernelContext::ResolveInputDynamismIntoPred(absl::string_view name,
+                                                        bool* out) {
+  TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
+  return ResolveInputDynamismIntoPred(index, out);
+}
+
+Status XlaOpKernelContext::ResolveInputDynamismReshaped(
+    int index, absl::Span<const int64_t> new_dims,
+    xla::Literal* dynamism_literal) {
   XlaExpression e = InputExpression(index);
   auto* client = compiler() ? compiler()->client() : nullptr;
   StatusOr<Tensor> dynamism_or_status = e.ResolveDynamism(client);
   if (!dynamism_or_status.ok()) {
+    xla::Literal true_literal = xla::LiteralUtil::CreateR0<bool>(true);
     // When failed to resolve dynamism, conservatively consider the value
     // dynamic. This could happen if the input depends on some ops like
     // custom-call that is not supported generally for dynamism computation.
-    //
-    // TODO(b/176993339): Support resolving dynamism across computations so
-    // resolving dynamism will not fail in those cases.
-    out->resize(InputShape(index).num_elements(), true);
+    *dynamism_literal =
+        true_literal
+            .Broadcast(xla::ShapeUtil::MakeShape(xla::PRED, new_dims), {})
+            .ValueOrDie();
+
     return Status::OK();
   }
   Tensor dynamism = dynamism_or_status.ValueOrDie();
 
   Tensor temp(dynamism.dtype());
-  TensorShape tensor_shape({InputShape(index).num_elements()});
-  if (!temp.CopyFrom(dynamism, tensor_shape)) {
+  if (!temp.CopyFrom(dynamism, TensorShape(new_dims))) {
     return errors::InvalidArgument(
         context_->op_kernel().name(), " input ", index, " has shape ",
-        dynamism.shape().DebugString(), " which is not a R1 ", tensor_shape);
+        dynamism.shape().DebugString(),
+        " but was asked to be reshaped to incompatible shape ",
+        TensorShape(new_dims).DebugString());
   }
 
-  TF_ASSIGN_OR_RETURN(literal, HostTensorToLiteral(temp));
+  TF_ASSIGN_OR_RETURN(*dynamism_literal, HostTensorToLiteral(temp));
+  return Status::OK();
+}
+
+Status XlaOpKernelContext::ResolveInputDynamismIntoPredVector(
+    int index, std::vector<bool>* out) {
+  xla::Literal literal;
+  TF_RETURN_IF_ERROR(ResolveInputDynamismReshaped(
+      index, {InputShape(index).num_elements()}, &literal));
+
   return LiteralToPredVector(literal, out);
 }
 
@@ -363,7 +406,7 @@
     absl::string_view name, std::vector<int64_t>* out,
     xla::ValueInferenceMode mode) {
   TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
-  return ConstantInputAsIntVector(index, out);
+  return ConstantInputAsIntVector(index, out, mode);
 }
 
 Status XlaOpKernelContext::ConstantInputReshapedToIntVector(
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index 29d40ea..e713607 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -69,6 +69,8 @@
   // Returns the XLA XlaBuilder containing the output of compilation.
   xla::XlaBuilder* builder() const;
 
+  xla::ValueInference& value_inference();
+
   // Inputs
 
   // Returns the number of inputs to the operator.
@@ -127,6 +129,17 @@
   // predicates.
   Status ResolveInputDynamismIntoPredVector(int index, std::vector<bool>* out);
   Status ResolveInputDynamismIntoPred(int index, bool* out);
+  Status ResolveInputDynamismIntoPredVector(absl::string_view name,
+                                            std::vector<bool>* out);
+  Status ResolveInputDynamismIntoPred(absl::string_view name, bool* out);
+
+  Status ResolveInputDynamism(int index, xla::Literal* dynamism_literal);
+  Status ResolveInputDynamism(absl::string_view name,
+                              xla::Literal* dynamism_literal);
+
+  Status ResolveInputDynamismReshaped(int index,
+                                      absl::Span<const int64_t> new_dims,
+                                      xla::Literal* dynamism_literal);
   // Helper methods for constant inputs.
 
   // Evaluates input `index` and stores it in `*constant_literal`. If the
@@ -329,7 +342,6 @@
  private:
   // Returns the tensor of input `name`.
   const Tensor& GetInputTensorByName(absl::string_view name);
-
   // Evaluates input `index`, reshapes it to `new_shape` if new_shape !=
   // InputShape(index), and stores it in `*constant_literal`. If the input
   // cannot be evaluated, e.g., because it depends on unbound parameters,
@@ -342,6 +354,7 @@
 
   OpKernelContext* const context_;
   bool dynamic_dimension_is_minus_one_;
+  xla::ValueInference value_inference_;
 };
 
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD
index 6f70a13..e73dd0f 100644
--- a/tensorflow/compiler/xla/client/lib/BUILD
+++ b/tensorflow/compiler/xla/client/lib/BUILD
@@ -142,6 +142,7 @@
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
+        "//tensorflow/compiler/xla/client:value_inference",
         "//tensorflow/compiler/xla/client:xla_builder",
     ],
 )
diff --git a/tensorflow/compiler/xla/client/lib/dynamic_shaped_ops.cc b/tensorflow/compiler/xla/client/lib/dynamic_shaped_ops.cc
index c4abf9e..e9541ec 100644
--- a/tensorflow/compiler/xla/client/lib/dynamic_shaped_ops.cc
+++ b/tensorflow/compiler/xla/client/lib/dynamic_shaped_ops.cc
@@ -109,4 +109,22 @@
                             false_operand, false_computation_rewritten);
   });
 }
+
+StatusOr<XlaOp> SetDimensionSizeWithRebound(ValueInference* value_inference,
+                                            XlaOp operand, XlaOp dimension_size,
+                                            int64_t dimension) {
+  auto inferred_bound_status_or = value_inference->AnalyzeConstant(
+      dimension_size, xla::ValueInferenceMode::kUpperBound);
+  TF_RETURN_IF_ERROR(inferred_bound_status_or.status());
+  if (inferred_bound_status_or->AllValid()) {
+    int64_t inferred_bound = inferred_bound_status_or->Get<int32>({}).value();
+    TF_ASSIGN_OR_RETURN(auto* shape_ptr,
+                        operand.builder()->GetShapePtr(operand));
+    // Found a tighter bound, do a slice.
+    if (shape_ptr->dimensions(dimension) > inferred_bound)
+      operand = xla::SliceInDim(operand, 0, inferred_bound, 1, dimension);
+  }
+  operand = xla::SetDimensionSize(operand, dimension_size, dimension);
+  return operand;
+}
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/client/lib/dynamic_shaped_ops.h b/tensorflow/compiler/xla/client/lib/dynamic_shaped_ops.h
index 4505376..cf318aa 100644
--- a/tensorflow/compiler/xla/client/lib/dynamic_shaped_ops.h
+++ b/tensorflow/compiler/xla/client/lib/dynamic_shaped_ops.h
@@ -17,6 +17,7 @@
 #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_DYNAMIC_SHAPED_OPS_H_
 
 #include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/value_inference.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/compiler/xla/types.h"
@@ -32,6 +33,12 @@
                          const XlaComputation& true_computation,
                          XlaOp false_operand,
                          const XlaComputation& false_computation);
+
+// Similar to SetDimensionSize, but automatically adjusts the bound of the
+// output if a tighter one can be inferred by `value_inference`.
+StatusOr<XlaOp> SetDimensionSizeWithRebound(ValueInference* value_inference,
+                                            XlaOp operand, XlaOp dimension_size,
+                                            int64_t dimension);
 }  // namespace xla
 
 #endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_DYNAMIC_SHAPED_OPS_H_
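
A hypothetical call site, mirroring how slice_op and strided_slice_op use the
helper above (`sliced`, `dynamic_size`, and the dimension index `i` are
placeholders):

  // `sliced` has a generous static bound in dimension `i`; `dynamic_size` is
  // an S32 scalar XlaOp holding the runtime size of that dimension.
  auto rebound = xla::SetDimensionSizeWithRebound(&ctx->value_inference(),
                                                  sliced, dynamic_size, i);
  OP_REQUIRES_OK(ctx, rebound.status());
  // The bound is trimmed first if value inference found a tighter one.
  sliced = rebound.ValueOrDie();
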
diff --git a/tensorflow/compiler/xla/client/value_inference.cc b/tensorflow/compiler/xla/client/value_inference.cc
index d91155f..0f26284 100644
--- a/tensorflow/compiler/xla/client/value_inference.cc
+++ b/tensorflow/compiler/xla/client/value_inference.cc
@@ -91,7 +91,7 @@
   if (element_type == TOKEN) {
     return LiteralUtil::CreateToken();
   }
-  Literal literal = LiteralUtil::Zero(element_type);
+  Literal literal = LiteralUtil::One(element_type);
   return literal.Broadcast(reference_shape, {}).ValueOrDie();
 }
 
@@ -374,7 +374,34 @@
     return false;
   }
 
-  absl::flat_hash_map<int64_t, Literal> evaluated;
+  struct CacheKey {
+    CacheKey(int64_t handle, InferenceContext context,
+             PostorderDFSNodeType type)
+        : handle(handle), context(context), type(type) {}
+    int64_t handle;
+    InferenceContext context;
+    PostorderDFSNodeType type;
+
+    template <typename H>
+    friend H AbslHashValue(H h, const CacheKey& key) {
+      h = H::combine(std::move(h), key.handle);
+      h = H::combine(std::move(h), key.context.shape_index.ToString());
+      h = H::combine(std::move(h),
+                     VectorString(key.context.caller_operand_handles));
+      h = H::combine(std::move(h), key.type);
+      return h;
+    }
+
+    friend bool operator==(const CacheKey& lhs, const CacheKey& rhs) {
+      return lhs.handle == rhs.handle &&
+             lhs.context.shape_index == rhs.context.shape_index &&
+             lhs.context.caller_operand_handles ==
+                 rhs.context.caller_operand_handles &&
+             lhs.type == rhs.type;
+    }
+  };
+
+  absl::flat_hash_map<CacheKey, Literal> evaluated;
   HandleToInstruction handle_to_instruction;
   HandleToComputation handle_to_computation;
 };
@@ -405,6 +432,8 @@
     case HloOpcode::kWhile:
     case HloOpcode::kSend:
     case HloOpcode::kRecv:
+    case HloOpcode::kSendDone:
+    case HloOpcode::kRecvDone:
     case HloOpcode::kParameter: {
       if (opcode == HloOpcode::kParameter &&
           !context.caller_operand_handles.empty()) {
@@ -665,6 +694,11 @@
             return Literal::CreateFromProto(root->literal());
           }
         });
+      } else if (root->custom_call_target() == "Sharding") {
+        return PostorderDFSNode()
+            .AddDependency(root->operand_ids(0),
+                           PostorderDFSNodeType::kConstantUpperBound, context)
+            .AddVisit([](Literal operand) { return operand; });
       }
       return InvalidArgument(
           "Upper-bound inferencing on custom call %s is not supported",
@@ -793,6 +827,11 @@
                            PostorderDFSNodeType::kConstantValue, context)
             .AddVisit(
                 [](Literal operand) -> StatusOr<Literal> { return operand; });
+      } else if (root->custom_call_target() == "Sharding") {
+        return PostorderDFSNode()
+            .AddDependency(root->operand_ids(0),
+                           PostorderDFSNodeType::kConstantValue, context)
+            .AddVisit([](Literal operand) { return operand; });
       } else {
         return PostorderDFSNode().AddVisit(
             [root, context](absl::Span<Literal>) {
@@ -805,9 +844,11 @@
     }
     case HloOpcode::kSort: {
       PostorderDFSNode result;
+      InferenceContext dep_context = context;
+      dep_context.shape_index = {};
       for (auto operand_id : root->operand_ids()) {
         result.AddDependency(operand_id, PostorderDFSNodeType::kConstantValue,
-                             context);
+                             dep_context);
       }
       const HloComputationProto* computation_proto =
           handle_to_computation(root->called_computation_ids(0));
@@ -908,17 +949,24 @@
       });
     }
     case HloOpcode::kSetDimensionSize:
-      return result.AddVisit([root, type]() {
+      return result.AddVisit([root, type](absl::Span<Literal> operands) {
+        bool any_dynamic_operand = absl::c_any_of(
+            operands, [](Literal& operand) { return !operand.IsAll(0); });
         // If values in a tensor `t` with bound are [e0, e1, e2...], we can say
         // the max value of each position is [max(t), max(t), max(t), ...]. The
         // effective size of this tensor doesn't change the max value.
         return CreatePredLiteral(
-            type == PostorderDFSNodeType::kValueIsDynamic,
+            type == PostorderDFSNodeType::kValueIsDynamic &&
+                any_dynamic_operand,
             ShapeUtil::MakeStaticShape(Shape(root->shape())));
       });
     case HloOpcode::kDynamicSlice: {
-      return result.AddVisit(
-          [root]() { return CreatePredLiteral(true, Shape(root->shape())); });
+      return result.AddVisit([root](absl::Span<Literal> operands) {
+        // If any of the operands is dynamic, we say the output is dynamic.
+        bool any_dynamic_operand = absl::c_any_of(
+            operands, [](Literal& operand) { return !operand.IsAll(0); });
+        return CreatePredLiteral(any_dynamic_operand, Shape(root->shape()));
+      });
     }
     case HloOpcode::kAbs:
     case HloOpcode::kRoundNearestAfz:
@@ -1182,7 +1230,6 @@
             Literal lhs = std::move(operands[2]);
             Literal rhs = std::move(operands[3]);
             auto result = CreatePredLiteral(true, Shape(root->shape()));
-
             result.MutableEachCell<bool>(
                 [&](absl::Span<const int64_t> indices, bool value) {
                   absl::optional<bool> optional_selector =
@@ -1248,6 +1295,8 @@
             }
           }
         });
+      } else if (root->custom_call_target() == "Sharding") {
+        return result.AddVisit([](Literal operand) { return operand; });
       } else {
         return InvalidArgument(
             "Dynamic inferencing on custom call %s is not supported",
@@ -1258,6 +1307,10 @@
     }
 
     case HloOpcode::kCall:
+    case HloOpcode::kRecv:
+    case HloOpcode::kRecvDone:
+    case HloOpcode::kSend:
+    case HloOpcode::kSendDone:
     case HloOpcode::kWhile: {
       return PostorderDFSNode().AddVisit([root,
                                           context]() -> StatusOr<Literal> {
@@ -1268,8 +1321,12 @@
       break;
     }
     default:
-      return Unimplemented("Can't infer dynamism through %s: %s",
-                           root->opcode(), root->DebugString());
+      return PostorderDFSNode().AddVisit([root,
+                                          context]() -> StatusOr<Literal> {
+        return CreatePredLiteral(
+            true,
+            ShapeUtil::GetSubshape(Shape(root->shape()), context.shape_index));
+      });
   }
 }
 
@@ -1296,8 +1353,10 @@
     VisitState state;
     Visit visit;  // The handler to call once the dependencies are resolved into
                   // literal form.
-    int64_t id;   // Unique id in the work queue, starting from 0.
-    std::vector<int64_t> dependencies;
+    int64_t id;     // Unique id in the work queue, starting from 0.
+    std::vector<CacheKey> dependencies;
+
+    CacheKey GetCacheKey() { return CacheKey(handle, context, type); }
   };
 
   std::vector<WorkItem> stack;
@@ -1315,21 +1374,25 @@
 
       // Gather dependencies and transform them into literals.
       std::vector<Literal> literals;
-      for (int64_t dep_id : item.dependencies) {
-        TF_RET_CHECK(evaluated.contains(dep_id));
-        literals.emplace_back(evaluated.at(dep_id).Clone());
+      for (CacheKey& dep_key : item.dependencies) {
+        TF_RET_CHECK(evaluated.contains(dep_key));
+        literals.emplace_back(evaluated.at(dep_key).Clone());
       }
       VLOG(1) << "Start visiting with dependency type: "
               << PostorderDFSNodeTypeToString(item.type);
       TF_ASSIGN_OR_RETURN(auto literal, item.visit(absl::MakeSpan(literals)));
       VLOG(1) << "End visiting: " << literal.ToString();
-      evaluated[item.id] = std::move(literal);
+      evaluated[item.GetCacheKey()] = std::move(literal);
       stack.pop_back();
       continue;
     }
     // This is the first time we see this node; we want to gather its
     // dependencies.
     VLOG(1) << "unvisited";
+    if (evaluated.contains(item.GetCacheKey())) {
+      stack.pop_back();
+      continue;
+    }
     item.state = kVisiting;
     PostorderDFSNode node;
     switch (item.type) {
@@ -1360,22 +1423,21 @@
     // resolved.
     item.visit = node.visit;
 
-    // Dependencies of this item have id in the range of [unique_id, unique_id +
-    // dependencies.size())
-    for (int64_t i = 0; i < node.dependencies.size(); ++i) {
-      item.dependencies.push_back(unique_id + i);
-    }
+    const int64_t current_item_id = stack.size() - 1;
     // Enqueue dependencies into the stack. `item` shouldn't be accessed after
     // this point.
     for (const PostorderDFSDep& dep : node.dependencies) {
-      VLOG(1) << "dep " << dep.annotation << ":"
-              << handle_to_instruction(dep.handle)->DebugString();
+      VLOG(1) << "dep " << dep.annotation
+              << "::" << handle_to_instruction(dep.handle)->DebugString()
+              << "index" << dep.context.shape_index
+              << " stack size:" << stack.size();
       stack.emplace_back(dep.handle, dep.context, dep.type, kUnvisited,
                          unique_id++);
+      stack[current_item_id].dependencies.push_back(stack.back().GetCacheKey());
     }
   }
-  VLOG(1) << "done" << evaluated[root.id].ToString();
-  return evaluated[root.id].Clone();
+  VLOG(1) << "done" << evaluated[root.GetCacheKey()].ToString();
+  return evaluated[root.GetCacheKey()].Clone();
 }
 
 StatusOr<Literal> ValueInference::AnalyzeIsDynamic(XlaOp op) {
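
The switch from caching by work-item id to caching by (handle, context, type)
matters because the DFS can reach the same instruction handle more than once,
e.g. the same kParameter queried through different callers (different
caller_operand_handles) or at different shape_index positions of a tuple.
Keying `evaluated` on the per-enqueue id never shared results across repeat
visits; keying on the full CacheKey both deduplicates them (the early-exit in
the kUnvisited branch) and keeps results from distinct contexts separate.
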
diff --git a/tensorflow/compiler/xla/tests/value_inference_test.cc b/tensorflow/compiler/xla/tests/value_inference_test.cc
index 9850b63..55bd8b0 100644
--- a/tensorflow/compiler/xla/tests/value_inference_test.cc
+++ b/tensorflow/compiler/xla/tests/value_inference_test.cc
@@ -259,6 +259,16 @@
   EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {1}).ValueOrDie(), false);
 }
 
+TEST_F(DynamismInferenceTest, DynamicSliceWithConstantOperands) {
+  XlaBuilder b(TestName());
+
+  auto constant = ConstantR1<int32>(&b, {0, 1, 2, 3});
+  auto slice_start = ConstantR0(&b, 1);
+  auto dynamic_slice = DynamicSlice(constant, {slice_start}, {1});
+  EXPECT_FALSE(
+      ComputeDynamismLiteral(dynamic_slice, &b).ValueOrDie().Get<bool>({0}));
+}
+
 TEST_F(DynamismInferenceTest, GatherWithCommonParent) {
   XlaBuilder b(TestName());
   // Test the analysis on a gather where first operand and second operand have
diff --git a/tensorflow/core/tpu/kernels/xla/segment_reduction_ops.cc b/tensorflow/core/tpu/kernels/xla/segment_reduction_ops.cc
index 1c9deaf..fa9dbb1 100644
--- a/tensorflow/core/tpu/kernels/xla/segment_reduction_ops.cc
+++ b/tensorflow/core/tpu/kernels/xla/segment_reduction_ops.cc
@@ -48,7 +48,9 @@
     TensorShape indices_shape = ctx->InputShape(1);
 
     int64_t num_segments;
-    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(2, &num_segments));
+    OP_REQUIRES_OK(ctx,
+                   ctx->ConstantInputAsIntScalar(
+                       2, &num_segments, xla::ValueInferenceMode::kUpperBound));
 
     OP_REQUIRES(ctx, data_shape.dims() >= indices_shape.dims(),
                 errors::InvalidArgument(