Start using value inference when resolving constant values in tf2xla kernels.
- Change the rest of the kernels that support dynamic shapes to use value inference.
- Update pad_op so that it can infer a tighter bound using value inference.
- Update slice_op so that it doesn't create a shape with all dynamic dimensions when some slice sizes are constant.
PiperOrigin-RevId: 391081565
Change-Id: Iccfbdd8b899bea8dbc42980c27b6c995233b3ccd
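
The recurring pattern in the kernels below: resolve a shape-carrying input to
its inferred upper bound (instead of failing compilation when the input is not
a compile-time constant), build the result with that bounded static shape, and
then re-attach runtime sizes on the dimensions that are actually dynamic. A
minimal sketch of the pattern (the kernel and BuildBoundedResult are
illustrative placeholders; the ctx/xla calls are the ones this change uses):

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape shape;
    // Falls back to the inferred upper bound when input 0 is not constant.
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(
                            0, &shape, xla::ValueInferenceMode::kUpperBound));
    xla::XlaOp result = BuildBoundedResult(ctx, shape);  // hypothetical
    std::vector<bool> dynamic_dims;
    OP_REQUIRES_OK(ctx,
                   ctx->ResolveInputDynamismIntoPredVector(0, &dynamic_dims));
    for (int64_t i = 0; i < shape.dims(); ++i) {
      if (dynamic_dims[i]) {
        // Read the runtime size out of the shape operand and attach it.
        auto size =
            xla::Reshape(xla::Slice(ctx->Input(0), {i}, {i + 1}, {1}), {});
        result = xla::SetDimensionSize(
            result, xla::ConvertElementType(size, xla::S32), i);
      }
    }
    ctx->SetOutput(0, result);
  }
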
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index f62323f..08c8e21 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -182,6 +182,7 @@
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/client/lib:comparators",
"//tensorflow/compiler/xla/client/lib:constants",
+ "//tensorflow/compiler/xla/client/lib:dynamic_shaped_ops",
"//tensorflow/compiler/xla/client/lib:loops",
"//tensorflow/compiler/xla/client/lib:math",
"//tensorflow/compiler/xla/client/lib:matrix",
diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc
index 7ada23a..4c6cb3b 100644
--- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc
@@ -20,6 +20,7 @@
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/value_inference.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -46,7 +47,8 @@
errors::InvalidArgument("In[", i, "] must be a vector.",
in_shape.DebugString()));
std::vector<int64_t> shape;
- OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(i, &shape));
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(
+ i, &shape, xla::ValueInferenceMode::kUpperBound));
shapes.push_back(BCast::Vec(shape.begin(), shape.end()));
}
BCast bcast(shapes[0], shapes[1]);
@@ -95,7 +97,11 @@
errors::InvalidArgument("In[", i, "] must be a vector.",
in_shape.DebugString()));
std::vector<int64_t> vec;
- OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(i, &vec));
+      // Technically we don't need to infer the upper bound here. However,
+      // the forward path uses the upper bound as the bounded shape, so the
+      // backward path must use the same shape to decide the reduction
+      // indices.
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(
+ i, &vec, xla::ValueInferenceMode::kUpperBound));
shapes.push_back(BCast::Vec(vec.begin(), vec.end()));
}
diff --git a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
index f872361..8847692 100644
--- a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
@@ -30,7 +30,9 @@
void Compile(XlaOpKernelContext* context) override {
TensorShape output_shape;
- OP_REQUIRES_OK(context, context->ConstantInputAsShape(1, &output_shape));
+ OP_REQUIRES_OK(context,
+ context->ConstantInputAsShape(
+ 1, &output_shape, xla::ValueInferenceMode::kUpperBound));
auto output_status_or =
BroadcastTo(context->Input(0), output_shape.dim_sizes());
OP_REQUIRES_OK(context, output_status_or.status());
diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
index 51e6b89..e2e0fc8 100644
--- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
@@ -45,7 +45,9 @@
const xla::XlaOp& logits = ctx->Input(0);
TensorShape logits_shape = ctx->InputShape(0);
int64_t num_samples;
- OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_samples));
+ OP_REQUIRES_OK(ctx,
+ ctx->ConstantInputAsIntScalar(
+ 1, &num_samples, xla::ValueInferenceMode::kUpperBound));
OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape),
errors::InvalidArgument("logits should be a matrix, got shape ",
logits_shape.DebugString()));
@@ -66,7 +68,10 @@
xla::Shape uniform_shape;
int class_dimension;
- if (num_samples != 1) {
+ bool num_samples_is_dynamic = false;
+ OP_REQUIRES_OK(
+ ctx, ctx->ResolveInputDynamismIntoPred(1, &num_samples_is_dynamic));
+ if (num_samples != 1 || num_samples_is_dynamic) {
std::array<int64_t, 3> uniform_shape_array = {
{batch_size, num_samples, num_classes}};
xla::PrimitiveType uniform_xla_type;
@@ -91,11 +96,9 @@
xla::PrimitiveType type;
OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(0), &type));
xla::XlaOp log_uniforms = GetLogUniforms(uniform_shape, type, ctx);
- bool num_samples_is_dynamic = false;
- OP_REQUIRES_OK(
- ctx, ctx->ResolveInputDynamismIntoPred(1, &num_samples_is_dynamic));
- if (num_samples_is_dynamic && num_samples != 1) {
- // Number samples is dimension 1 in uniform_shape_array.
+
+ if (num_samples_is_dynamic) {
+ // num_samples is dimension 1 in uniform_shape_array.
log_uniforms = xla::SetDimensionSize(log_uniforms, ctx->Input(1), 1);
}
@@ -119,7 +122,7 @@
/*axis=*/class_dimension, /*stable=*/true);
}
- if (num_samples == 1) {
+ if (num_samples == 1 && !num_samples_is_dynamic) {
argmax = xla::Reshape(argmax, {batch_size, 1});
}
diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc
index fbf403b..1cb67d8 100644
--- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc
@@ -155,7 +155,9 @@
const int32_t N = ctx->num_inputs() - 1;
const TensorShape inp0_shape = ctx->InputShape(1);
std::vector<int64_t> inp0_dims;
- OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &inp0_dims));
+ OP_REQUIRES_OK(ctx,
+ ctx->ConstantInputAsIntVector(
+ 1, &inp0_dims, xla::ValueInferenceMode::kUpperBound));
const int64_t inp0_rank = inp0_shape.num_elements();
int64_t cdim;
@@ -174,7 +176,9 @@
inp0_rank, " elements, but got ",
inp_shape.num_elements()));
std::vector<int64_t> inp_dims;
- OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1 + i, &inp_dims));
+ OP_REQUIRES_OK(
+ ctx, ctx->ConstantInputAsIntVector(
+ 1 + i, &inp_dims, xla::ValueInferenceMode::kUpperBound));
Tensor out_constant(DT_INT32, TensorShape({inp0_rank}));
auto out_vec = out_constant.vec<int32>();
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
index 53bbb86..6ebb6b5 100644
--- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
@@ -104,7 +104,9 @@
void Compile(XlaOpKernelContext* ctx) override {
TensorShape input_tensor_shape;
- OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_tensor_shape));
+ OP_REQUIRES_OK(
+ ctx, ctx->ConstantInputAsShape(0, &input_tensor_shape,
+ xla::ValueInferenceMode::kUpperBound));
xla::Shape input_shape =
TensorShapeToXLAShape(ctx->input_xla_type(1), input_tensor_shape);
OP_REQUIRES(ctx, input_shape.rank() == attrs_.num_spatial_dims + 2,
@@ -170,7 +172,9 @@
void Compile(XlaOpKernelContext* ctx) override {
TensorShape filter_tensor_shape;
- OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &filter_tensor_shape));
+ OP_REQUIRES_OK(
+ ctx, ctx->ConstantInputAsShape(1, &filter_tensor_shape,
+ xla::ValueInferenceMode::kUpperBound));
xla::Shape filter_shape =
TensorShapeToXLAShape(ctx->input_xla_type(0), filter_tensor_shape);
diff --git a/tensorflow/compiler/tf2xla/kernels/fill_op.cc b/tensorflow/compiler/tf2xla/kernels/fill_op.cc
index 1c9d81d2..ebcbadb 100644
--- a/tensorflow/compiler/tf2xla/kernels/fill_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/fill_op.cc
@@ -19,6 +19,7 @@
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/value_inference.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/register_types.h"
@@ -45,17 +46,17 @@
value_shape.DebugString()));
std::vector<int64_t> dims;
- OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector("dims", &dims));
- // Set dynamic dimension value to -1 so that we know which dimension is
- // dynamic.
- ctx->set_dynamic_dimension_is_minus_one(true);
- std::vector<int64> dynamic_dims;
- OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector("dims", &dynamic_dims));
+ OP_REQUIRES_OK(ctx,
+ ctx->ConstantInputAsIntVector(
+ "dims", &dims, xla::ValueInferenceMode::kUpperBound));
+ std::vector<bool> dynamic_dims;
+ OP_REQUIRES_OK(
+ ctx, ctx->ResolveInputDynamismIntoPredVector("dims", &dynamic_dims));
auto output = xla::Broadcast(ctx->Input("value"), dims);
for (int64_t i = 0; i < dims.size(); ++i) {
// If a dimension is dynamic, call set-dimension-size on the output.
- if (dynamic_dims[i] == -1) {
+ if (dynamic_dims[i]) {
auto dynamic_dim_size = xla::Slice(ctx->Input(0), {i}, {i + 1}, {1});
dynamic_dim_size = xla::Reshape(dynamic_dim_size, {});
dynamic_dim_size = xla::ConvertElementType(dynamic_dim_size, xla::S32);
diff --git a/tensorflow/compiler/tf2xla/kernels/if_while_utils.cc b/tensorflow/compiler/tf2xla/kernels/if_while_utils.cc
index c00cdf5..dcd2c07 100644
--- a/tensorflow/compiler/tf2xla/kernels/if_while_utils.cc
+++ b/tensorflow/compiler/tf2xla/kernels/if_while_utils.cc
@@ -38,16 +38,21 @@
VLOG(1) << "Trying to resolve constant " << i;
// NOTE: We can not simply check that this is Kind::kConstant because
// this could be the output of a MetadataOnly op e.g. Size.
+
+  // If we can infer the constant values of an inner computation's argument,
+  // replace them with constants. If that fails, we fall back to inferring
+  // the bounds of the argument.
StatusOr<absl::optional<Tensor>> maybe_constant =
expression.ResolveConstant(ctx->compiler()->client());
- if (maybe_constant.ok() && maybe_constant.ValueOrDie().has_value()) {
+ StatusOr<absl::optional<Tensor>> bounds =
+ expression.ResolveConstant(ctx->compiler()->client(), false,
+ xla::ValueInferenceMode::kUpperBound);
+ if ((maybe_constant.ok() && maybe_constant->has_value()) ||
+ (bounds.ok() && bounds->has_value())) {
StatusOr<Tensor> values_are_dynamic =
expression.ResolveDynamism(ctx->compiler()->client());
bool all_values_are_static = false;
- if (!values_are_dynamic.ok()) {
- // Conservatiely assume all values are dynamic.
- all_values_are_static = true;
- } else {
+ if (values_are_dynamic.ok()) {
xla::Literal literal =
HostTensorToLiteral(values_are_dynamic.ValueOrDie()).ValueOrDie();
all_values_are_static = literal.IsAll(0);
@@ -60,8 +65,7 @@
arg->shape = expression.GetShape().ValueOrDie();
resolved_constant_idxs.push_back(i);
} else {
- arg->value_bound.emplace(
- std::move(maybe_constant.ValueOrDie().value()));
+ arg->value_bound.emplace(std::move(bounds.ValueOrDie().value()));
arg->value_dynamism.emplace(
std::move(values_are_dynamic.ValueOrDie()));
}
diff --git a/tensorflow/compiler/tf2xla/kernels/pad_op.cc b/tensorflow/compiler/tf2xla/kernels/pad_op.cc
index 792a3cb..09119a7 100644
--- a/tensorflow/compiler/tf2xla/kernels/pad_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/pad_op.cc
@@ -17,7 +17,9 @@
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/value_inference.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -52,8 +54,13 @@
}
xla::Literal pad_literal;
- OP_REQUIRES_OK(ctx,
- ctx->ConstantInputAsInt64Literal("paddings", &pad_literal));
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(
+ "paddings", &pad_literal,
+ xla::ValueInferenceMode::kUpperBound));
+
+ xla::Literal padding_dynamism_literal;
+ OP_REQUIRES_OK(
+ ctx, ctx->ResolveInputDynamism("paddings", &padding_dynamism_literal));
xla::PaddingConfig config;
for (int i = 0; i < dims; ++i) {
@@ -69,15 +76,60 @@
// PadV2 added a "constant_values" input that indicates the pad value.
xla::XlaOp constant_values;
+ xla::XlaOp pad;
if (ctx->num_inputs() == 3) {
OP_REQUIRES(
ctx, TensorShapeUtils::IsScalar(ctx->InputShape("constant_values")),
errors::InvalidArgument("constant_values must be a scalar."));
- ctx->SetOutput(0, xla::Pad(input, ctx->Input("constant_values"), config));
+ pad = xla::Pad(input, ctx->Input("constant_values"), config);
} else {
auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0));
- ctx->SetOutput(0, xla::Pad(input, zero, config));
+ pad = xla::Pad(input, zero, config);
}
+
+ for (int i = 0; i < dims; ++i) {
+ bool low_pad_is_dynamic = padding_dynamism_literal.Get<bool>({i, 0});
+
+ OP_REQUIRES(
+ ctx, !low_pad_is_dynamic,
+ errors::InvalidArgument("low_pad in Pad op has to be static."));
+ bool high_pad_is_dynamic = padding_dynamism_literal.Get<bool>({i, 1});
+ if (high_pad_is_dynamic) {
+        // When we have
+        //   pad_width = MAX_WIDTH - size(t)
+        //   op = pad(t, /*high_pad=*/pad_width)
+        // the bound of the result size should be MAX_WIDTH, instead of
+        // `bound(t) + bound(pad_width)`.
+        //
+        // We get this by building the expression
+        //   size(op) = size(t) + MAX_WIDTH - size(t)
+        // and letting value inference infer the tighter upper bound.
+ xla::XlaOp high_pad_size =
+ xla::Slice(ctx->Input("paddings"), {i, 1}, {i + 1, 2}, {1, 1});
+ high_pad_size = xla::Reshape(high_pad_size, {});
+ high_pad_size = xla::ConvertElementType(high_pad_size, xla::S32);
+ // Low pad has to be static.
+ xla::XlaOp low_pad_size = xla::ConstantR0<int32>(
+ ctx->builder(), pad_literal.Get<int64>({i, 0}));
+ xla::XlaOp input_size = xla::GetDimensionSize(input, i);
+ xla::XlaOp total_size = low_pad_size + input_size + high_pad_size;
+ auto size_upper_bound_status_or =
+ ctx->value_inference().AnalyzeConstant(
+ total_size, xla::ValueInferenceMode::kUpperBound);
+ OP_REQUIRES_OK(ctx, size_upper_bound_status_or.status());
+ auto size_upper_bound =
+ size_upper_bound_status_or.ValueOrDie().Get<int32>({});
+ OP_REQUIRES(
+ ctx, size_upper_bound.has_value(),
+ errors::InvalidArgument(
+                  "Failed to infer upper bound of total size after padding."));
+          // If we know a tighter upper bound, trim the output to the new
+          // bound.
+ pad = xla::SliceInDim(pad, 0, size_upper_bound.value(), 1, i);
+ pad = xla::SetDimensionSize(pad, total_size, i);
+ }
+ }
+ ctx->SetOutput(0, pad);
}
};
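
A worked instance of the tighter pad bound (numbers illustrative): if t has
bound 8 with dynamic size s and paddings = [[0, 16 - s]], the naive output
bound would be bound(t) + bound(16 - s) = 8 + 16 = 24, but the upper bound of
total_size = 0 + s + (16 - s) is provably 16, so the padded output is sliced
down to 16 before SetDimensionSize re-attaches the dynamic total. The same
analysis written directly against a builder (a sketch; `input` and `high_pad`
are assumed ops, with `high_pad` computing 16 - s):

  xla::XlaOp total_size = xla::ConstantR0<int32>(builder, /*low_pad=*/0) +
                          xla::GetDimensionSize(input, /*dimension=*/0) +
                          high_pad;
  auto bound = value_inference.AnalyzeConstant(
      total_size, xla::ValueInferenceMode::kUpperBound);
  // On success, bound->Get<int32>({}) is 16 here, tighter than 24.
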
diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
index 3997e1d..546b205 100644
--- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
@@ -30,6 +30,7 @@
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
+#include "tensorflow/compiler/xla/client/value_inference.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
@@ -44,7 +45,8 @@
void Compile(XlaOpKernelContext* ctx) override {
TensorShape shape;
- OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(
+ 0, &shape, xla::ValueInferenceMode::kUpperBound));
const DataType dtype = output_type(0);
xla::Shape xla_shape;
@@ -58,7 +60,18 @@
<< name();
xla::XlaOp result = xla::RngUniform(XlaHelpers::Zero(b, dtype),
XlaHelpers::One(b, dtype), xla_shape);
-
+ std::vector<bool> dynamic_dims;
+ OP_REQUIRES_OK(ctx,
+ ctx->ResolveInputDynamismIntoPredVector(0, &dynamic_dims));
+ for (int64_t i = 0; i < xla_shape.rank(); ++i) {
+ // If a dimension is dynamic, call set-dimension-size on the output.
+ if (dynamic_dims[i]) {
+ auto dynamic_dim_size = xla::Slice(ctx->Input(0), {i}, {i + 1}, {1});
+ dynamic_dim_size = xla::Reshape(dynamic_dim_size, {});
+ dynamic_dim_size = xla::ConvertElementType(dynamic_dim_size, xla::S32);
+ result = xla::SetDimensionSize(result, dynamic_dim_size, i);
+ }
+ }
ctx->SetOutput(0, result);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc
index fca4938..41f2c2e 100644
--- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc
@@ -19,6 +19,7 @@
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/value_inference.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
namespace tensorflow {
@@ -57,8 +58,9 @@
TensorShape indices_shape = ctx->InputShape(1);
int64_t num_segments;
- OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(2, &num_segments));
-
+ OP_REQUIRES_OK(ctx,
+ ctx->ConstantInputAsIntScalar(
+ 2, &num_segments, xla::ValueInferenceMode::kUpperBound));
OP_REQUIRES(ctx, data_shape.dims() >= indices_shape.dims(),
errors::InvalidArgument(type_string(),
" requires that indices' rank be"
diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc
index 3d7a94e..8dc9b61 100644
--- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc
@@ -21,6 +21,8 @@
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/dynamic_shaped_ops.h"
+#include "tensorflow/compiler/xla/client/value_inference.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -57,10 +59,11 @@
std::vector<int64_t> begin;
std::vector<int64_t> size;
- const bool begin_is_constant =
+ const bool all_begins_are_constant =
ctx->ConstantInputAsIntVector(1, &begin).ok();
- const bool size_is_constant = ctx->ConstantInputAsIntVector(2, &size).ok();
- if (begin_is_constant && size_is_constant) {
+ const bool all_sizes_are_constant =
+ ctx->ConstantInputAsIntVector(2, &size).ok();
+ if (all_begins_are_constant && all_sizes_are_constant) {
std::vector<int64_t> wrapped_size(size.size());
// `begin` is a compile-time constant.
for (int i = 0; i < input_dims; ++i) {
@@ -101,12 +104,12 @@
std::vector<int64_t> strides(begin.size(), 1);
auto slice = xla::Slice(ctx->Input(0), begin, limits, strides);
// Check for slice on dynamic dimensions.
- ctx->set_dynamic_dimension_is_minus_one(true);
- std::vector<int64> dynamic_size;
- OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &dynamic_size));
+ std::vector<bool> size_is_dynamic;
+ OP_REQUIRES_OK(
+ ctx, ctx->ResolveInputDynamismIntoPredVector(2, &size_is_dynamic));
for (int64_t i = 0; i < size.size(); ++i) {
- if (dynamic_size[i] == -1) {
+ if (size_is_dynamic[i]) {
if (size[i] != -1) {
// If there is a dynamic dimension, properly set dimension size of
// the slice.
@@ -124,7 +127,7 @@
// This essentially makes size as dynamic.
bool constant_size_is_minus_one = false;
// `begin` or `size` is not a compile-time constant.
- if (size_is_constant) {
+ if (all_sizes_are_constant) {
for (int i = 0; i < input_dims; ++i) {
if (size[i] < 0) {
OP_REQUIRES(ctx, size[i] == -1,
@@ -147,9 +150,9 @@
begin_indices.push_back(
xla::Reshape(xla::Slice(begin, {i}, {i + 1}, {1}), {}));
}
- if (size_is_constant && !constant_size_is_minus_one) {
- ctx->SetOutput(0,
- xla::DynamicSlice(ctx->Input(0), begin_indices, size));
+ if (all_sizes_are_constant && !constant_size_is_minus_one) {
+ xla::XlaOp input = ctx->Input(0);
+ ctx->SetOutput(0, xla::DynamicSlice(input, begin_indices, size));
} else {
// Size is not constant, use input size as upperbound and then set
// dimension size on it.
@@ -157,16 +160,17 @@
// First pad input with input size to avoid OOB -- dynamic slice with
// OOB slice produces undesired results.
xla::PaddingConfig padding_config;
+ xla::XlaOp input = ctx->Input(0);
for (int64_t i = 0; i < input_dims; ++i) {
auto* dims = padding_config.add_dimensions();
dims->set_edge_padding_low(0);
dims->set_edge_padding_high(input_shape.dim_size(i));
dims->set_interior_padding(0);
+ input = xla::RemoveDynamicDimension(input, i);
}
- auto padded_input = xla::Pad(
- ctx->Input(0), xla::Zero(ctx->builder(), ctx->input_xla_type(0)),
- padding_config);
-
+ auto padded_input =
+ xla::Pad(input, xla::Zero(ctx->builder(), ctx->input_xla_type(0)),
+ padding_config);
// Slice full size out of the input starting from the offsets.
auto sliced = xla::DynamicSlice(padded_input, begin_indices,
input_shape.dim_sizes());
@@ -179,7 +183,23 @@
input_shape.dim_size(i)) -
begin_indices[i];
}
- sliced = xla::SetDimensionSize(sliced, dynamic_size, i);
+ auto constant_size = ctx->value_inference().AnalyzeConstant(
+ dynamic_size, xla::ValueInferenceMode::kValue);
+ OP_REQUIRES_OK(ctx, constant_size.status());
+ if (constant_size->AllValid()) {
+ // Slice size on this dimension is constant. This branch is
+              // triggered when some dimensions' slice sizes are constant
+              // while others are dynamic.
+ sliced = xla::SliceInDim(
+ sliced, 0, constant_size->Get<int32>({}).value(), 1, i);
+ } else {
+              // We gave a generous bound (same as the input) to the output;
+              // try to reset the bound if a tighter one can be found.
+ auto status = xla::SetDimensionSizeWithRebound(
+ &ctx->value_inference(), sliced, dynamic_size, i);
+ OP_REQUIRES_OK(ctx, status.status());
+ sliced = status.ValueOrDie();
+ }
}
ctx->SetOutput(0, sliced);
}
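
The per-element probe above is the commit's last bullet in action: even when
the whole `size` vector fails ConstantInputAsIntVector because one element is
dynamic, each element is analyzed individually, so constant dimensions stay
static. A condensed sketch of the decision for one dimension (same calls as
the kernel; `dynamic_size` is the scalar S32 size op for dimension i):

  auto constant_size = ctx->value_inference().AnalyzeConstant(
      dynamic_size, xla::ValueInferenceMode::kValue);
  OP_REQUIRES_OK(ctx, constant_size.status());
  if (constant_size->AllValid()) {
    // Known at compile time: emit a static slice, keeping the dim static.
    sliced = xla::SliceInDim(sliced, 0, constant_size->Get<int32>({}).value(),
                             /*stride=*/1, /*dimno=*/i);
  } else {
    // Truly dynamic: attach the runtime size, tightening the bound when a
    // smaller one can be proven.
    auto rebounded = xla::SetDimensionSizeWithRebound(
        &ctx->value_inference(), sliced, dynamic_size, i);
    OP_REQUIRES_OK(ctx, rebounded.status());
    sliced = rebounded.ValueOrDie();
  }
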
diff --git a/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc b/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc
index e6fe82f..bb5dfa5 100644
--- a/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc
@@ -41,7 +41,9 @@
// output_shape
TensorShape output_shape;
- OP_REQUIRES_OK(context, context->ConstantInputAsShape(1, &output_shape));
+ OP_REQUIRES_OK(context,
+ context->ConstantInputAsShape(
+ 1, &output_shape, xla::ValueInferenceMode::kUpperBound));
OP_REQUIRES(context, output_shape.dims() == num_dims,
errors::InvalidArgument(
"output_shape has incorrect number of elements: ",
@@ -72,7 +74,20 @@
}
xla::XlaBuilder* builder = context->builder();
auto buffer = Broadcast(default_value, output_shape.dim_sizes());
+ std::vector<bool> dynamic_dims;
+ OP_REQUIRES_OK(
+ context, context->ResolveInputDynamismIntoPredVector(1, &dynamic_dims));
+ for (int64_t i = 0; i < dynamic_dims.size(); ++i) {
+ // If a dimension is dynamic, call set-dimension-size on the output.
+ if (dynamic_dims[i]) {
+ auto dynamic_dim_size =
+ xla::Slice(context->Input(1), {i}, {i + 1}, {1});
+ dynamic_dim_size = xla::Reshape(dynamic_dim_size, {});
+ dynamic_dim_size = xla::ConvertElementType(dynamic_dim_size, xla::S32);
+ buffer = xla::SetDimensionSize(buffer, dynamic_dim_size, i);
+ }
+ }
auto result = XlaScatter(buffer, sparse_values, indices,
/*indices_are_vectors=*/indices_shape.dims() > 1,
/*combiner=*/{}, builder);
diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
index fe0cc9c..4a4dcb2 100644
--- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
@@ -25,6 +25,8 @@
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/dynamic_shaped_ops.h"
+#include "tensorflow/compiler/xla/client/value_inference.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/util.h"
@@ -134,6 +136,11 @@
// Pad input to 2x to avoid OOB access.
slice = xla::Pad(slice, xla::Zero(ctx->builder(), ctx->input_xla_type(0)),
padding_config);
+ for (int64 i = 0; i < result_dims_are_dynamic.size(); ++i) {
+ if (result_dims_are_dynamic[i]) {
+ slice = xla::RemoveDynamicDimension(slice, i);
+ }
+ }
}
std::vector<xla::XlaOp> start_indices;
std::vector<xla::XlaOp> slice_sizes_dynamic;
@@ -156,16 +163,30 @@
xla::ConstantR0WithType(ctx->builder(), ctx->InputXlaType("begin"),
input_xla_shape.dimensions(i));
}
+
+ auto scalar_must_be_non_negative = [ctx](xla::XlaOp value) -> bool {
+      // Check whether the inferred lower bound of a value is >= 0.
+ auto lower_bound = ctx->value_inference().AnalyzeConstant(
+ value, xla::ValueInferenceMode::kLowerBound);
+ if (!lower_bound.ok() || !lower_bound->AllValid()) {
+ // Can't infer a lower bound.
+ return false;
+ }
+ return lower_bound->Get<int32>({}) >= 0;
+ };
if (begin_mask) {
begin_index = zero;
} else {
begin_index = xla::Slice(ctx->Input("begin"), {sparse_index},
{sparse_index + 1}, {1});
begin_index = xla::Reshape(begin_index, {});
- auto index_negative = xla::Lt(begin_index, zero);
- auto wrapped_index = xla::Add(dim_size, begin_index);
- // Wrap negative indices around.
- begin_index = xla::Select(index_negative, wrapped_index, begin_index);
+ if (!scalar_must_be_non_negative(begin_index)) {
+ // begin could be negative.
+ auto index_negative = xla::Lt(begin_index, zero);
+ auto wrapped_index = xla::Add(dim_size, begin_index);
+ // Wrap negative indices around.
+ begin_index = xla::Select(index_negative, wrapped_index, begin_index);
+ }
}
start_indices.push_back(begin_index);
if (end_mask) {
@@ -174,9 +195,12 @@
end_index = xla::Slice(ctx->Input("end"), {sparse_index},
{sparse_index + 1}, {1});
end_index = xla::Reshape(end_index, {});
- auto index_negative = xla::Lt(end_index, zero);
- auto wrapped_index = xla::Add(dim_size, end_index);
- end_index = xla::Select(index_negative, wrapped_index, end_index);
+ if (!scalar_must_be_non_negative(end_index)) {
+ // end could be negative.
+ auto index_negative = xla::Lt(end_index, zero);
+ auto wrapped_index = xla::Add(dim_size, end_index);
+ end_index = xla::Select(index_negative, wrapped_index, end_index);
+ }
}
slice_sizes_dynamic.push_back(
xla::Max(xla::Sub(end_index, begin_index), zero));
@@ -184,13 +208,24 @@
slice =
xla::DynamicSlice(slice, start_indices, processing_shape.dim_sizes());
-
- for (int64_t i = 0; i < input_shape.dims(); ++i) {
- if (result_dims_are_dynamic[i]) {
- slice = xla::SetDimensionSize(slice, slice_sizes_dynamic[i], i);
+  // new_axis_mask_, ellipsis_mask_, and shrink_axis_mask_ may add or remove
+  // size-1 dims from the shape.
+ slice = xla::Reshape(slice, final_shape.dim_sizes());
+ for (int64_t i = 0; i < final_shape.dims(); ++i) {
+ int64 processing_shape_dim = shape_spec.output_to_processing_mapping[i];
+ // If processing_shape_dim is -1, it means the output dimension was newly
+    // added by new_axis_mask_, which doesn't show up in the input.
+ if (processing_shape_dim != -1 &&
+ result_dims_are_dynamic[processing_shape_dim]) {
+      // We gave a generous bound (same as the input) to the output; try to
+      // reset the bound if a tighter one can be found.
+ auto status = xla::SetDimensionSizeWithRebound(
+ &ctx->value_inference(), slice,
+ slice_sizes_dynamic[processing_shape_dim], i);
+ OP_REQUIRES_OK(ctx, status.status());
+ slice = status.ValueOrDie();
}
}
- slice = xla::Reshape(slice, final_shape.dim_sizes());
ctx->SetOutput(0, slice);
}
@@ -483,7 +518,9 @@
absl::InlinedVector<int64_t, 4> strides;
TensorShape input_shape;
- OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape));
+ OP_REQUIRES_OK(
+ ctx, ctx->ConstantInputAsShape(0, &input_shape,
+ xla::ValueInferenceMode::kUpperBound));
xla::Literal begin_literal, end_literal, strides_literal;
bool begin_is_constant = ctx->ConstantInput(1, &begin_literal).ok();
@@ -565,14 +602,14 @@
xla::XlaOp dynamic_shape = ctx->Input(0);
xla::Shape grad_shape = ctx->builder()->GetShape(grad).ValueOrDie();
- ctx->set_dynamic_dimension_is_minus_one(true);
- std::vector<int64> dynamic_size;
- OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(0, &dynamic_size));
+ std::vector<bool> dynamic_input;
+ OP_REQUIRES_OK(ctx,
+ ctx->ResolveInputDynamismIntoPredVector(0, &dynamic_input));
// Input of strided_slice_op has to have the same shape as output.
DCHECK_EQ(grad_shape.rank(), input_shape.dims());
for (int64_t dim = 0; dim < input_shape.dims(); ++dim) {
DCHECK_EQ(grad_shape.dimensions(dim), input_shape.dim_size(dim));
- if (dynamic_size[dim] == -1) {
+ if (dynamic_input[dim]) {
// Input is a dynamic dimension, set the same dynamic dimension size in
// the output.
auto dim_size = xla::Slice(dynamic_shape, {dim}, {dim + 1}, {1});
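
The scalar_must_be_non_negative helper above is what lets the kernel skip the
Lt/Add/Select wrap-around when value inference proves an index is already
>= 0 (e.g. when it flows from a Size or Shape op). A standalone sketch of the
lower-bound probe, assuming `index` is a scalar S32 op and `zero`/`dim_size`
are as in the kernel:

  auto lower_bound = ctx->value_inference().AnalyzeConstant(
      index, xla::ValueInferenceMode::kLowerBound);
  // AllValid() means a bound was inferred for every element (here: one).
  bool known_non_negative = lower_bound.ok() && lower_bound->AllValid() &&
                            lower_bound->Get<int32>({}).value() >= 0;
  if (!known_non_negative) {
    // Only pay for the wrap-around when the index could be negative.
    index =
        xla::Select(xla::Lt(index, zero), xla::Add(dim_size, index), index);
  }
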
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc
index e043fbc..ef06e9d 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc
@@ -144,7 +144,9 @@
void Compile(XlaOpKernelContext* ctx) override {
int64_t num_elements;
- OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_elements));
+ OP_REQUIRES_OK(ctx,
+ ctx->ConstantInputAsIntScalar(
+ 1, &num_elements, xla::ValueInferenceMode::kUpperBound));
bool num_element_is_dynamic;
OP_REQUIRES_OK(
ctx, ctx->ResolveInputDynamismIntoPred(1, &num_element_is_dynamic));
@@ -214,7 +216,9 @@
void Compile(XlaOpKernelContext* ctx) override {
int64_t max_num_elements;
- OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &max_num_elements));
+ OP_REQUIRES_OK(
+ ctx, ctx->ConstantInputAsIntScalar(
+ 1, &max_num_elements, xla::ValueInferenceMode::kUpperBound));
bool num_element_is_dynamic;
OP_REQUIRES_OK(
ctx, ctx->ResolveInputDynamismIntoPred(1, &num_element_is_dynamic));
diff --git a/tensorflow/compiler/tf2xla/xla_expression.cc b/tensorflow/compiler/tf2xla/xla_expression.cc
index c3b9a51..8bd93eb 100644
--- a/tensorflow/compiler/tf2xla/xla_expression.cc
+++ b/tensorflow/compiler/tf2xla/xla_expression.cc
@@ -166,8 +166,14 @@
"ResolveConstant called on XlaExpression: ", HumanString());
}
TF_ASSIGN_OR_RETURN(TensorShape shape, GetShape());
+ // The XLA layout is specified minor to major, and TensorFlow uses a major to
+ // minor order.
+ std::vector<int64_t> layout_indices(shape.dims());
+ std::iota(layout_indices.rbegin(), layout_indices.rend(), 0);
+ xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices);
if (mode == xla::ValueInferenceMode::kLowerBound ||
- mode == xla::ValueInferenceMode::kUpperBound) {
+ mode == xla::ValueInferenceMode::kUpperBound ||
+ mode == xla::ValueInferenceMode::kValue) {
std::vector<int64_t> layout_indices(shape.dims());
std::iota(layout_indices.rbegin(), layout_indices.rend(), 0);
xla::ValueInference value_inference(handle().builder());
@@ -177,8 +183,8 @@
return {absl::nullopt};
}
Tensor tensor;
- TF_RETURN_IF_ERROR(
- LiteralToHostTensor(literal.GetValue().value(), dtype(), &tensor));
+ TF_RETURN_IF_ERROR(LiteralToHostTensor(
+ literal.GetValue().value().Relayout(layout), dtype(), &tensor));
return {tensor};
}
@@ -195,11 +201,6 @@
handle().builder()->BuildConstantSubGraph(
handle(), dynamic_dimension_is_minus_one));
- // The XLA layout is specified minor to major, and TensorFlow uses a major to
- // minor order.
- std::vector<int64> layout_indices(shape.dims());
- std::iota(layout_indices.rbegin(), layout_indices.rend(), 0);
- xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices);
TF_ASSIGN_OR_RETURN(xla::Literal literal,
client->ComputeConstant(constant_graph, &layout));
Tensor tensor;
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index 1f533ad..4c2c066 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -23,6 +23,7 @@
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
+#include "tensorflow/compiler/xla/client/value_inference.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -32,7 +33,9 @@
namespace tensorflow {
XlaOpKernelContext::XlaOpKernelContext(OpKernelContext* context)
- : context_(context), dynamic_dimension_is_minus_one_(false) {}
+ : context_(context),
+ dynamic_dimension_is_minus_one_(false),
+ value_inference_(xla_context()->builder()) {}
bool XlaOpKernelContext::ValidateInputsAreSameShape(OpKernel* op) {
return context_->ValidateInputsAreSameShape(op);
@@ -46,6 +49,10 @@
return xla_context()->builder();
}
+xla::ValueInference& XlaOpKernelContext::value_inference() {
+ return value_inference_;
+}
+
XlaCompiler* XlaOpKernelContext::compiler() const {
return xla_context()->compiler();
}
@@ -154,6 +161,18 @@
return start;
}
+Status XlaOpKernelContext::ResolveInputDynamism(
+ int index, xla::Literal* dynamism_literal) {
+ return ResolveInputDynamismReshaped(
+ index, context_->input(index).shape().dim_sizes(), dynamism_literal);
+}
+
+Status XlaOpKernelContext::ResolveInputDynamism(
+ absl::string_view name, xla::Literal* dynamism_literal) {
+ TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
+ return ResolveInputDynamism(index, dynamism_literal);
+}
+
Status XlaOpKernelContext::ConstantInput(absl::string_view name,
xla::Literal* constant_literal,
xla::ValueInferenceMode mode) {
@@ -301,32 +320,56 @@
}
Status XlaOpKernelContext::ResolveInputDynamismIntoPredVector(
- int index, std::vector<bool>* out) {
- xla::Literal literal;
+ absl::string_view name, std::vector<bool>* out) {
+ TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
+ return ResolveInputDynamismIntoPredVector(index, out);
+}
+
+Status XlaOpKernelContext::ResolveInputDynamismIntoPred(absl::string_view name,
+ bool* out) {
+ TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
+ return ResolveInputDynamismIntoPred(index, out);
+}
+
+Status XlaOpKernelContext::ResolveInputDynamismReshaped(
+ int index, absl::Span<const int64_t> new_dims,
+ xla::Literal* dynamism_literal) {
XlaExpression e = InputExpression(index);
auto* client = compiler() ? compiler()->client() : nullptr;
StatusOr<Tensor> dynamism_or_status = e.ResolveDynamism(client);
if (!dynamism_or_status.ok()) {
+ xla::Literal true_literal = xla::LiteralUtil::CreateR0<bool>(true);
// When failed to resolve dynamism, conservatively consider the value
// dynamic. This could happen if the input depends on some ops like
// custom-call that is not supported generally for dynamism computation.
- //
- // TODO(b/176993339): Support resolving dynamism across computations so
- // resolving dynamism will not fail in those cases.
- out->resize(InputShape(index).num_elements(), true);
+ *dynamism_literal =
+ true_literal
+ .Broadcast(xla::ShapeUtil::MakeShape(xla::PRED, new_dims), {})
+ .ValueOrDie();
+
return Status::OK();
}
Tensor dynamism = dynamism_or_status.ValueOrDie();
Tensor temp(dynamism.dtype());
- TensorShape tensor_shape({InputShape(index).num_elements()});
- if (!temp.CopyFrom(dynamism, tensor_shape)) {
+ if (!temp.CopyFrom(dynamism, TensorShape(new_dims))) {
return errors::InvalidArgument(
context_->op_kernel().name(), " input ", index, " has shape ",
- dynamism.shape().DebugString(), " which is not a R1 ", tensor_shape);
+ dynamism.shape().DebugString(),
+        " but was asked to be reshaped to the incompatible shape ",
+ TensorShape(new_dims).DebugString());
}
- TF_ASSIGN_OR_RETURN(literal, HostTensorToLiteral(temp));
+ TF_ASSIGN_OR_RETURN(*dynamism_literal, HostTensorToLiteral(temp));
+ return Status::OK();
+}
+
+Status XlaOpKernelContext::ResolveInputDynamismIntoPredVector(
+ int index, std::vector<bool>* out) {
+ xla::Literal literal;
+ TF_RETURN_IF_ERROR(ResolveInputDynamismReshaped(
+ index, {InputShape(index).num_elements()}, &literal));
+
return LiteralToPredVector(literal, out);
}
@@ -363,7 +406,7 @@
absl::string_view name, std::vector<int64_t>* out,
xla::ValueInferenceMode mode) {
TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
- return ConstantInputAsIntVector(index, out);
+ return ConstantInputAsIntVector(index, out, mode);
}
Status XlaOpKernelContext::ConstantInputReshapedToIntVector(
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index 29d40ea..e713607 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -69,6 +69,8 @@
// Returns the XLA XlaBuilder containing the output of compilation.
xla::XlaBuilder* builder() const;
+ xla::ValueInference& value_inference();
+
// Inputs
// Returns the number of inputs to the operator.
@@ -127,6 +129,17 @@
// predicates.
Status ResolveInputDynamismIntoPredVector(int index, std::vector<bool>* out);
Status ResolveInputDynamismIntoPred(int index, bool* out);
+ Status ResolveInputDynamismIntoPredVector(absl::string_view name,
+ std::vector<bool>* out);
+ Status ResolveInputDynamismIntoPred(absl::string_view name, bool* out);
+
+ Status ResolveInputDynamism(int index, xla::Literal* dynamism_literal);
+ Status ResolveInputDynamism(absl::string_view name,
+ xla::Literal* dynamism_literal);
+
+ Status ResolveInputDynamismReshaped(int index,
+ absl::Span<const int64_t> new_dims,
+ xla::Literal* dynamism_literal);
// Helper methods for constant inputs.
// Evaluates input `index` and stores it in `*constant_literal`. If the
@@ -329,7 +342,6 @@
private:
// Returns the tensor of input `name`.
const Tensor& GetInputTensorByName(absl::string_view name);
-
// Evaluates input `index`, reshapes it to `new_shape` if new_shape !=
// InputShape(index), and stores it in `*constant_literal`. If the input
// cannot be evaluated, e.g., because it depends on unbound parameters,
@@ -342,6 +354,7 @@
OpKernelContext* const context_;
bool dynamic_dimension_is_minus_one_;
+ xla::ValueInference value_inference_;
};
} // namespace tensorflow
diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD
index 6f70a13..e73dd0f 100644
--- a/tensorflow/compiler/xla/client/lib/BUILD
+++ b/tensorflow/compiler/xla/client/lib/BUILD
@@ -142,6 +142,7 @@
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
+ "//tensorflow/compiler/xla/client:value_inference",
"//tensorflow/compiler/xla/client:xla_builder",
],
)
diff --git a/tensorflow/compiler/xla/client/lib/dynamic_shaped_ops.cc b/tensorflow/compiler/xla/client/lib/dynamic_shaped_ops.cc
index c4abf9e..e9541ec 100644
--- a/tensorflow/compiler/xla/client/lib/dynamic_shaped_ops.cc
+++ b/tensorflow/compiler/xla/client/lib/dynamic_shaped_ops.cc
@@ -109,4 +109,22 @@
false_operand, false_computation_rewritten);
});
}
+
+StatusOr<XlaOp> SetDimensionSizeWithRebound(ValueInference* value_inference,
+ XlaOp operand, XlaOp dimension_size,
+ int64_t dimension) {
+ auto inferred_bound_status_or = value_inference->AnalyzeConstant(
+ dimension_size, xla::ValueInferenceMode::kUpperBound);
+ TF_RETURN_IF_ERROR(inferred_bound_status_or.status());
+ if (inferred_bound_status_or->AllValid()) {
+ int64_t inferred_bound = inferred_bound_status_or->Get<int32>({}).value();
+ TF_ASSIGN_OR_RETURN(auto* shape_ptr,
+ operand.builder()->GetShapePtr(operand));
+ // Found a tighter bound, do a slice.
+ if (shape_ptr->dimensions(dimension) > inferred_bound)
+ operand = xla::SliceInDim(operand, 0, inferred_bound, 1, dimension);
+ }
+ operand = xla::SetDimensionSize(operand, dimension_size, dimension);
+ return operand;
+}
} // namespace xla
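
A hypothetical call site for the new helper (`values` and `new_size` are
illustrative names): `values` carries a static bound of, say, 1024, but value
inference can prove `new_size` never exceeds 128, so the helper slices
dimension 0 down to 128 before attaching the dynamic size. With no provably
tighter bound it degenerates to a plain SetDimensionSize.

  xla::ValueInference value_inference(values.builder());
  StatusOr<xla::XlaOp> rebounded = xla::SetDimensionSizeWithRebound(
      &value_inference, values, new_size, /*dimension=*/0);
  TF_RETURN_IF_ERROR(rebounded.status());
  values = rebounded.ValueOrDie();  // bound is now 128; size is new_size
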
diff --git a/tensorflow/compiler/xla/client/lib/dynamic_shaped_ops.h b/tensorflow/compiler/xla/client/lib/dynamic_shaped_ops.h
index 4505376..cf318aa 100644
--- a/tensorflow/compiler/xla/client/lib/dynamic_shaped_ops.h
+++ b/tensorflow/compiler/xla/client/lib/dynamic_shaped_ops.h
@@ -17,6 +17,7 @@
#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_DYNAMIC_SHAPED_OPS_H_
#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/value_inference.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/types.h"
@@ -32,6 +33,12 @@
const XlaComputation& true_computation,
XlaOp false_operand,
const XlaComputation& false_computation);
+
+// Similar to SetDimensionSize, but automatically adjusts the bound of the
+// output if a tighter one can be inferred by `value_inference`.
+StatusOr<XlaOp> SetDimensionSizeWithRebound(ValueInference* value_inference,
+ XlaOp operand, XlaOp dimension_size,
+ int64_t dimension);
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_DYNAMIC_SHAPED_OPS_H_
diff --git a/tensorflow/compiler/xla/client/value_inference.cc b/tensorflow/compiler/xla/client/value_inference.cc
index d91155f..0f26284 100644
--- a/tensorflow/compiler/xla/client/value_inference.cc
+++ b/tensorflow/compiler/xla/client/value_inference.cc
@@ -91,7 +91,7 @@
if (element_type == TOKEN) {
return LiteralUtil::CreateToken();
}
- Literal literal = LiteralUtil::Zero(element_type);
+ Literal literal = LiteralUtil::One(element_type);
return literal.Broadcast(reference_shape, {}).ValueOrDie();
}
@@ -374,7 +374,34 @@
return false;
}
- absl::flat_hash_map<int64_t, Literal> evaluated;
+ struct CacheKey {
+ CacheKey(int64_t handle, InferenceContext context,
+ PostorderDFSNodeType type)
+ : handle(handle), context(context), type(type) {}
+ int64_t handle;
+ InferenceContext context;
+ PostorderDFSNodeType type;
+
+ template <typename H>
+ friend H AbslHashValue(H h, const CacheKey& key) {
+ h = H::combine(std::move(h), key.handle);
+ h = H::combine(std::move(h), key.context.shape_index.ToString());
+ h = H::combine(std::move(h),
+ VectorString(key.context.caller_operand_handles));
+ h = H::combine(std::move(h), key.type);
+ return h;
+ }
+
+ friend bool operator==(const CacheKey& lhs, const CacheKey& rhs) {
+ return lhs.handle == rhs.handle &&
+ lhs.context.shape_index == rhs.context.shape_index &&
+ lhs.context.caller_operand_handles ==
+ rhs.context.caller_operand_handles &&
+ lhs.type == rhs.type;
+ }
+ };
+
+ absl::flat_hash_map<CacheKey, Literal> evaluated;
HandleToInstruction handle_to_instruction;
HandleToComputation handle_to_computation;
};
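
Keying the memo table on (handle, shape_index, caller chain, inference type)
instead of a per-visit work-item id means an instruction revisited in the
same context is evaluated once and its literal reused. The early-out this
enables (added in a later hunk of this file) looks like:

  if (evaluated.contains(item.GetCacheKey())) {
    stack.pop_back();  // already inferred for this (handle, context, type)
    continue;
  }
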
@@ -405,6 +432,8 @@
case HloOpcode::kWhile:
case HloOpcode::kSend:
case HloOpcode::kRecv:
+ case HloOpcode::kSendDone:
+ case HloOpcode::kRecvDone:
case HloOpcode::kParameter: {
if (opcode == HloOpcode::kParameter &&
!context.caller_operand_handles.empty()) {
@@ -665,6 +694,11 @@
return Literal::CreateFromProto(root->literal());
}
});
+ } else if (root->custom_call_target() == "Sharding") {
+ return PostorderDFSNode()
+ .AddDependency(root->operand_ids(0),
+ PostorderDFSNodeType::kConstantUpperBound, context)
+ .AddVisit([](Literal operand) { return operand; });
}
return InvalidArgument(
"Upper-bound inferencing on custom call %s is not supported",
@@ -793,6 +827,11 @@
PostorderDFSNodeType::kConstantValue, context)
.AddVisit(
[](Literal operand) -> StatusOr<Literal> { return operand; });
+ } else if (root->custom_call_target() == "Sharding") {
+ return PostorderDFSNode()
+ .AddDependency(root->operand_ids(0),
+ PostorderDFSNodeType::kConstantValue, context)
+ .AddVisit([](Literal operand) { return operand; });
} else {
return PostorderDFSNode().AddVisit(
[root, context](absl::Span<Literal>) {
@@ -805,9 +844,11 @@
}
case HloOpcode::kSort: {
PostorderDFSNode result;
+ InferenceContext dep_context = context;
+ dep_context.shape_index = {};
for (auto operand_id : root->operand_ids()) {
result.AddDependency(operand_id, PostorderDFSNodeType::kConstantValue,
- context);
+ dep_context);
}
const HloComputationProto* computation_proto =
handle_to_computation(root->called_computation_ids(0));
@@ -908,17 +949,24 @@
});
}
case HloOpcode::kSetDimensionSize:
- return result.AddVisit([root, type]() {
+ return result.AddVisit([root, type](absl::Span<Literal> operands) {
+ bool any_dynamic_operand = absl::c_any_of(
+ operands, [](Literal& operand) { return !operand.IsAll(0); });
// If values in a tensor `t` with bound are [e0, e1, e2...], we can say
// the max value of each position is [max(t), max(t), max(t), ...]. The
// effective size of this tensor doesn't change the max value.
return CreatePredLiteral(
- type == PostorderDFSNodeType::kValueIsDynamic,
+ type == PostorderDFSNodeType::kValueIsDynamic &&
+ any_dynamic_operand,
ShapeUtil::MakeStaticShape(Shape(root->shape())));
});
case HloOpcode::kDynamicSlice: {
- return result.AddVisit(
- [root]() { return CreatePredLiteral(true, Shape(root->shape())); });
+ return result.AddVisit([root](absl::Span<Literal> operands) {
+        // If any of the operands is dynamic, we say the output is dynamic.
+ bool any_dynamic_operand = absl::c_any_of(
+ operands, [](Literal& operand) { return !operand.IsAll(0); });
+ return CreatePredLiteral(any_dynamic_operand, Shape(root->shape()));
+ });
}
case HloOpcode::kAbs:
case HloOpcode::kRoundNearestAfz:
@@ -1182,7 +1230,6 @@
Literal lhs = std::move(operands[2]);
Literal rhs = std::move(operands[3]);
auto result = CreatePredLiteral(true, Shape(root->shape()));
-
result.MutableEachCell<bool>(
[&](absl::Span<const int64_t> indices, bool value) {
absl::optional<bool> optional_selector =
@@ -1248,6 +1295,8 @@
}
}
});
+ } else if (root->custom_call_target() == "Sharding") {
+ return result.AddVisit([](Literal operand) { return operand; });
} else {
return InvalidArgument(
"Dynamic inferencing on custom call %s is not supported",
@@ -1258,6 +1307,10 @@
}
case HloOpcode::kCall:
+ case HloOpcode::kRecv:
+ case HloOpcode::kRecvDone:
+ case HloOpcode::kSend:
+ case HloOpcode::kSendDone:
case HloOpcode::kWhile: {
return PostorderDFSNode().AddVisit([root,
context]() -> StatusOr<Literal> {
@@ -1268,8 +1321,12 @@
break;
}
default:
- return Unimplemented("Can't infer dynamism through %s: %s",
- root->opcode(), root->DebugString());
+ return PostorderDFSNode().AddVisit([root,
+ context]() -> StatusOr<Literal> {
+ return CreatePredLiteral(
+ true,
+ ShapeUtil::GetSubshape(Shape(root->shape()), context.shape_index));
+ });
}
}
@@ -1296,8 +1353,10 @@
VisitState state;
Visit visit; // The handler to call once the dependencies are resolved into
// literal form.
- int64_t id; // Unique id in the work queue, starting from 0.
- std::vector<int64_t> dependencies;
+ int64_t id; // Unique id in the work queue, starting from 0.
+ std::vector<CacheKey> dependencies;
+
+ CacheKey GetCacheKey() { return CacheKey(handle, context, type); }
};
std::vector<WorkItem> stack;
@@ -1315,21 +1374,25 @@
// Gather dependencies and transform them into literals.
std::vector<Literal> literals;
- for (int64_t dep_id : item.dependencies) {
- TF_RET_CHECK(evaluated.contains(dep_id));
- literals.emplace_back(evaluated.at(dep_id).Clone());
+ for (CacheKey& dep_key : item.dependencies) {
+ TF_RET_CHECK(evaluated.contains(dep_key));
+ literals.emplace_back(evaluated.at(dep_key).Clone());
}
VLOG(1) << "Start visiting with dependency type: "
<< PostorderDFSNodeTypeToString(item.type);
TF_ASSIGN_OR_RETURN(auto literal, item.visit(absl::MakeSpan(literals)));
VLOG(1) << "End visiting: " << literal.ToString();
- evaluated[item.id] = std::move(literal);
+ evaluated[item.GetCacheKey()] = std::move(literal);
stack.pop_back();
continue;
}
    // This is the first time we see this node; we want to gather its
    // dependencies.
VLOG(1) << "unvisited";
+ if (evaluated.contains(item.GetCacheKey())) {
+ stack.pop_back();
+ continue;
+ }
item.state = kVisiting;
PostorderDFSNode node;
switch (item.type) {
@@ -1360,22 +1423,21 @@
// resolved.
item.visit = node.visit;
- // Dependencies of this item have id in the range of [unique_id, unique_id +
- // dependencies.size())
- for (int64_t i = 0; i < node.dependencies.size(); ++i) {
- item.dependencies.push_back(unique_id + i);
- }
+ const int64_t current_item_id = stack.size() - 1;
// Enqueue dependencies into the stack. `item` shouldn't be accessed after
// this point.
for (const PostorderDFSDep& dep : node.dependencies) {
- VLOG(1) << "dep " << dep.annotation << ":"
- << handle_to_instruction(dep.handle)->DebugString();
+      VLOG(1) << "dep " << dep.annotation
+              << "::" << handle_to_instruction(dep.handle)->DebugString()
+              << " index: " << dep.context.shape_index
+              << " stack size: " << stack.size();
stack.emplace_back(dep.handle, dep.context, dep.type, kUnvisited,
unique_id++);
+ stack[current_item_id].dependencies.push_back(stack.back().GetCacheKey());
}
}
- VLOG(1) << "done" << evaluated[root.id].ToString();
- return evaluated[root.id].Clone();
+ VLOG(1) << "done" << evaluated[root.GetCacheKey()].ToString();
+ return evaluated[root.GetCacheKey()].Clone();
}
StatusOr<Literal> ValueInference::AnalyzeIsDynamic(XlaOp op) {
diff --git a/tensorflow/compiler/xla/tests/value_inference_test.cc b/tensorflow/compiler/xla/tests/value_inference_test.cc
index 9850b63..55bd8b0 100644
--- a/tensorflow/compiler/xla/tests/value_inference_test.cc
+++ b/tensorflow/compiler/xla/tests/value_inference_test.cc
@@ -259,6 +259,16 @@
EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {1}).ValueOrDie(), false);
}
+TEST_F(DynamismInferenceTest, DynamicSliceWithConstantOperands) {
+ XlaBuilder b(TestName());
+
+ auto constant = ConstantR1<int32>(&b, {0, 1, 2, 3});
+ auto slice_start = ConstantR0(&b, 1);
+ auto dynamic_slice = DynamicSlice(constant, {slice_start}, {1});
+ EXPECT_FALSE(
+ ComputeDynamismLiteral(dynamic_slice, &b).ValueOrDie().Get<bool>({0}));
+}
+
TEST_F(DynamismInferenceTest, GatherWithCommonParent) {
XlaBuilder b(TestName());
// Test the analysis on a gather where first operand and second operand have
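
A hypothetical companion test for the new "Sharding" custom-call pass-through
(not part of this change; it follows the conventions of the existing tests in
this file): dynamism should flow through the annotation unchanged.

  TEST_F(DynamismInferenceTest, ShardingCustomCallForwardsOperandDynamism) {
    XlaBuilder b(TestName());
    auto constant = ConstantR0<int32>(&b, 42);
    auto sharded = CustomCall(&b, "Sharding", {constant},
                              ShapeUtil::MakeShape(S32, {}));
    // The operand is a constant, so the annotated value is static too.
    EXPECT_FALSE(
        ComputeDynamismLiteral(sharded, &b).ValueOrDie().Get<bool>({}));
  }
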
diff --git a/tensorflow/core/tpu/kernels/xla/segment_reduction_ops.cc b/tensorflow/core/tpu/kernels/xla/segment_reduction_ops.cc
index 1c9deaf..fa9dbb1 100644
--- a/tensorflow/core/tpu/kernels/xla/segment_reduction_ops.cc
+++ b/tensorflow/core/tpu/kernels/xla/segment_reduction_ops.cc
@@ -48,7 +48,9 @@
TensorShape indices_shape = ctx->InputShape(1);
int64_t num_segments;
- OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(2, &num_segments));
+ OP_REQUIRES_OK(ctx,
+ ctx->ConstantInputAsIntScalar(
+ 2, &num_segments, xla::ValueInferenceMode::kUpperBound));
OP_REQUIRES(ctx, data_shape.dims() >= indices_shape.dims(),
errors::InvalidArgument(