[Resubmit] Value inference on upper and lower bounds.
- This CL allows XLA to infer both the upper and the lower bound of common ops.
- Clients can trigger the inference through the XLA op kernel API.
- Useful for inferring bounds through subtract and divide, for example:
T = ...  # some_dynamic_tensor
tf.range(x - tf.size(T))  # The upper bound of the range size is x, since tf.size(T) >= 0.
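The bound follows the subtract rule implemented in this change (the upper
bound of a subtraction uses the lower bound of its second operand):
  upper-bound(x - tf.size(T)) = upper-bound(x) - lower-bound(tf.size(T))
                              = x - 0 = x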
PiperOrigin-RevId: 367139035
Change-Id: I0fe76aa631af287a532c1b99a5e3d4452fbf711c
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index edc89aa..4f93d3e 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -362,6 +362,7 @@
"//tensorflow/compiler/jit:shape_inference",
"//tensorflow/compiler/mlir:array_container_utils",
"//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
+ "//tensorflow/compiler/xla/client:value_inference",
"//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc
index dfbad70..04b97ce 100644
--- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc
@@ -84,8 +84,10 @@
errors::InvalidArgument("delta must be a scalar, not shape ",
delta_in_shape.DebugString()));
xla::Literal start, limit, delta;
- OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &start));
- OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &limit));
+ OP_REQUIRES_OK(ctx, ctx->ConstantInput(
+ 0, &start, xla::ValueInferenceMode::kLowerBound));
+ OP_REQUIRES_OK(ctx, ctx->ConstantInput(
+ 1, &limit, xla::ValueInferenceMode::kUpperBound));
OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &delta));
DataType type = input_type(0);
@@ -108,23 +110,29 @@
DataTypeString(type));
}
OP_REQUIRES_OK(ctx, output.status());
+ bool start_is_dynamic = false;
+ OP_REQUIRES_OK(ctx,
+ ctx->ResolveInputDynamismIntoPred(0, &start_is_dynamic));
+ bool limit_is_dynamic = false;
+ OP_REQUIRES_OK(ctx,
+ ctx->ResolveInputDynamismIntoPred(1, &limit_is_dynamic));
- if (type == DT_INT32 || type == DT_INT64) {
- bool limit_is_dynamic = false;
- OP_REQUIRES_OK(ctx,
- ctx->ResolveInputDynamismIntoPred(1, &limit_is_dynamic));
- if (type == DT_INT32) {
- if (limit_is_dynamic) {
- output = xla::SetDimensionSize(output.ValueOrDie(), ctx->Input(1), 0);
- }
+ if (start_is_dynamic || limit_is_dynamic) {
+ xla::XlaOp delta = ctx->Input(2);
+ xla::XlaOp limit = ctx->Input(1);
+ xla::XlaOp start = ctx->Input(0);
+ if (type == DT_INT32 || type == DT_INT64) {
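+        // Integer case: size = ceil(|limit - start| / |delta|), computed as
+        // (|limit - start| + |delta| - 1) / |delta| with truncating division.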
+ auto dynamic_size =
+ ((xla::Abs(limit - start) + xla::Abs(delta) -
+ xla::One(ctx->builder(), ctx->input_xla_type(0))) /
+ xla::Abs(delta));
+ output = xla::SetDimensionSize(output.ValueOrDie(), dynamic_size, 0);
} else {
- if (limit_is_dynamic) {
- output = xla::SetDimensionSize(
- output.ValueOrDie(),
- xla::ConvertElementType(ctx->Input(1), xla::S32), 0);
- }
+        auto dynamic_size = xla::Ceil(xla::Abs((limit - start) / delta));
+ output = xla::SetDimensionSize(output.ValueOrDie(), dynamic_size, 0);
}
}
+
ctx->SetOutput(0, output.ValueOrDie());
}
};
diff --git a/tensorflow/compiler/tf2xla/xla_expression.cc b/tensorflow/compiler/tf2xla/xla_expression.cc
index 099d546..87fdc63 100644
--- a/tensorflow/compiler/tf2xla/xla_expression.cc
+++ b/tensorflow/compiler/tf2xla/xla_expression.cc
@@ -152,7 +152,8 @@
}
xla::StatusOr<absl::optional<Tensor>> XlaExpression::ResolveConstant(
- xla::Client* client, bool dynamic_dimension_is_minus_one) const {
+ xla::Client* client, bool dynamic_dimension_is_minus_one,
+ xla::ValueInferenceMode mode) const {
switch (kind()) {
case Kind::kConstant:
case Kind::kResource:
@@ -165,6 +166,22 @@
return errors::InvalidArgument(
"ResolveConstant called on XlaExpression: ", HumanString());
}
+ TF_ASSIGN_OR_RETURN(TensorShape shape, GetShape());
+ if (mode == xla::ValueInferenceMode::kLowerBound ||
+ mode == xla::ValueInferenceMode::kUpperBound) {
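+    // Ask ValueInference for a bound literal; if any element of the bound
+    // cannot be inferred, give up and return nullopt.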
+ std::vector<int64> layout_indices(shape.dims());
+ std::iota(layout_indices.rbegin(), layout_indices.rend(), 0);
+ xla::ValueInference value_inference(handle().builder());
+ TF_ASSIGN_OR_RETURN(xla::OptionalLiteral literal,
+ value_inference.AnalyzeConstant(handle(), mode));
+ if (!literal.GetValue().has_value()) {
+ return {absl::nullopt};
+ }
+ Tensor tensor;
+ TF_RETURN_IF_ERROR(
+ LiteralToHostTensor(literal.GetValue().value(), dtype(), &tensor));
+ return {tensor};
+ }
TF_ASSIGN_OR_RETURN(bool is_constant,
handle().builder()->IsConstant(handle()));
@@ -179,8 +196,6 @@
handle().builder()->BuildConstantSubGraph(
handle(), dynamic_dimension_is_minus_one));
- TF_ASSIGN_OR_RETURN(TensorShape shape, GetShape());
-
// The XLA layout is specified minor to major, and TensorFlow uses a major to
// minor order.
std::vector<int64> layout_indices(shape.dims());
diff --git a/tensorflow/compiler/tf2xla/xla_expression.h b/tensorflow/compiler/tf2xla/xla_expression.h
index 408afef..69f367d 100644
--- a/tensorflow/compiler/tf2xla/xla_expression.h
+++ b/tensorflow/compiler/tf2xla/xla_expression.h
@@ -19,6 +19,7 @@
#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/xla_resource.h"
#include "tensorflow/compiler/xla/client/client.h"
+#include "tensorflow/compiler/xla/client/value_inference.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/tensor.h"
@@ -106,16 +107,17 @@
// Returns a human-readable summary of the expression.
string HumanString() const;
- // Returns the value of a kConstant or kXlaOp as an xla::XlaOp. Returns
+ // Returns the value of a kValue or kXlaOp as an xla::XlaOp. Returns
// an erroneous XlaOp if the expression is not a constant or an expression.
xla::XlaOp AsXlaOp(xla::XlaBuilder* builder) const;
- // If a kXlaOp or kConstant expression can be resolved to a compile-time
+ // If a kXlaOp or kValue expression can be resolved to a compile-time
// constant, returns the value as a host-memory Tensor. Returns an empty
// optional if it cannot be resolved. Returns an error if passed a resource
// expression.
xla::StatusOr<absl::optional<Tensor>> ResolveConstant(
- xla::Client* client, bool dynamic_dimension_is_minus_one = false) const;
+ xla::Client* client, bool dynamic_dimension_is_minus_one = false,
+ xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue) const;
// ResolveDynamism computes where a value inside this op is dynamic or can be
// inferred at compile time.
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index f2eb038..7b02df2 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -128,9 +128,11 @@
}
Status XlaOpKernelContext::ConstantInput(int index,
- xla::Literal* constant_literal) {
- return ConstantInputReshaped(
- index, context_->input(index).shape().dim_sizes(), constant_literal);
+ xla::Literal* constant_literal,
+ xla::ValueInferenceMode mode) {
+ return ConstantInputReshaped(index,
+ context_->input(index).shape().dim_sizes(),
+ constant_literal, mode);
}
static xla::StatusOr<int> InputIndex(XlaOpKernelContext* context,
@@ -147,18 +149,19 @@
}
Status XlaOpKernelContext::ConstantInput(absl::string_view name,
- xla::Literal* constant_literal) {
+ xla::Literal* constant_literal,
+ xla::ValueInferenceMode mode) {
TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
- return ConstantInput(index, constant_literal);
+ return ConstantInput(index, constant_literal, mode);
}
Status XlaOpKernelContext::ConstantInputReshaped(
- int index, absl::Span<const int64> new_dims,
- xla::Literal* constant_literal) {
+ int index, absl::Span<const int64> new_dims, xla::Literal* constant_literal,
+ xla::ValueInferenceMode mode) {
XlaExpression e = InputExpression(index);
auto* client = compiler() ? compiler()->client() : nullptr;
xla::StatusOr<absl::optional<Tensor>> constant_or_status =
- e.ResolveConstant(client, dynamic_dimension_is_minus_one_);
+ e.ResolveConstant(client, dynamic_dimension_is_minus_one_, mode);
if (!constant_or_status.ok()) {
Status status = constant_or_status.status();
errors::AppendToMessage(&status, "while evaluating input ", index, " of ",
@@ -225,21 +228,23 @@
return Status::OK();
}
-Status XlaOpKernelContext::ConstantInputAsIntScalar(int index, int64* out) {
+Status XlaOpKernelContext::ConstantInputAsIntScalar(
+ int index, int64* out, xla::ValueInferenceMode mode) {
xla::Literal literal;
- TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
+ TF_RETURN_IF_ERROR(ConstantInput(index, &literal, mode));
return LiteralToInt64Scalar(literal, out);
}
-Status XlaOpKernelContext::ConstantInputAsIntScalar(absl::string_view name,
- int64* out) {
+Status XlaOpKernelContext::ConstantInputAsIntScalar(
+ absl::string_view name, int64* out, xla::ValueInferenceMode mode) {
TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
- return ConstantInputAsIntScalar(index, out);
+ return ConstantInputAsIntScalar(index, out, mode);
}
-Status XlaOpKernelContext::ConstantInputAsFloatScalar(int index, double* out) {
+Status XlaOpKernelContext::ConstantInputAsFloatScalar(
+ int index, double* out, xla::ValueInferenceMode mode) {
xla::Literal literal;
- TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
+ TF_RETURN_IF_ERROR(ConstantInput(index, &literal, mode));
return LiteralToFloat64Scalar(literal, out);
}
@@ -341,40 +346,42 @@
return Status::OK();
}
-Status XlaOpKernelContext::ConstantInputAsIntVector(int index,
- std::vector<int64>* out) {
+Status XlaOpKernelContext::ConstantInputAsIntVector(
+ int index, std::vector<int64>* out, xla::ValueInferenceMode mode) {
xla::Literal literal;
- TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
+ TF_RETURN_IF_ERROR(ConstantInput(index, &literal, mode));
return LiteralToInt64Vector(literal, out);
}
-Status XlaOpKernelContext::ConstantInputAsIntVector(absl::string_view name,
- std::vector<int64>* out) {
+Status XlaOpKernelContext::ConstantInputAsIntVector(
+ absl::string_view name, std::vector<int64>* out,
+ xla::ValueInferenceMode mode) {
TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
-  return ConstantInputAsIntVector(index, out);
+  return ConstantInputAsIntVector(index, out, mode);
}
Status XlaOpKernelContext::ConstantInputReshapedToIntVector(
- int index, std::vector<int64>* out) {
+ int index, std::vector<int64>* out, xla::ValueInferenceMode mode) {
xla::Literal literal;
TF_RETURN_IF_ERROR(ConstantInputReshaped(
- index, {InputShape(index).num_elements()}, &literal));
+ index, {InputShape(index).num_elements()}, &literal, mode));
return LiteralToInt64Vector(literal, out);
}
Status XlaOpKernelContext::ConstantInputReshapedToIntVector(
- absl::string_view name, std::vector<int64>* out) {
+ absl::string_view name, std::vector<int64>* out,
+ xla::ValueInferenceMode mode) {
TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
xla::Literal literal;
TF_RETURN_IF_ERROR(ConstantInputReshaped(
- index, {InputShape(index).num_elements()}, &literal));
+ index, {InputShape(index).num_elements()}, &literal, mode));
return LiteralToInt64Vector(literal, out);
}
-Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index,
- xla::Literal* out) {
+Status XlaOpKernelContext::ConstantInputAsInt64Literal(
+ int index, xla::Literal* out, xla::ValueInferenceMode mode) {
xla::Literal literal;
- TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
+ TF_RETURN_IF_ERROR(ConstantInput(index, &literal, mode));
switch (literal.shape().element_type()) {
case xla::S32: {
*out = xla::Literal(
@@ -396,17 +403,18 @@
}
}
-Status XlaOpKernelContext::ConstantInputAsInt64Literal(absl::string_view name,
- xla::Literal* out) {
+Status XlaOpKernelContext::ConstantInputAsInt64Literal(
+ absl::string_view name, xla::Literal* out, xla::ValueInferenceMode mode) {
TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
- return ConstantInputAsInt64Literal(index, out);
+ return ConstantInputAsInt64Literal(index, out, mode);
}
// TODO(phawkins): validate that the dimensions form a valid shape, fail
// gracefully if they do not.
-Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) {
+Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape,
+ xla::ValueInferenceMode mode) {
xla::Literal literal;
- TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
+ TF_RETURN_IF_ERROR(ConstantInput(index, &literal, mode));
std::vector<int64> dims;
TF_RETURN_IF_ERROR(LiteralToInt64Vector(literal, &dims));
*shape = TensorShape(dims);
@@ -449,13 +457,14 @@
return Status::OK();
}
-Status XlaOpKernelContext::ConstantInputList(
- absl::string_view name, std::vector<xla::Literal>* outputs) {
+Status XlaOpKernelContext::ConstantInputList(absl::string_view name,
+ std::vector<xla::Literal>* outputs,
+ xla::ValueInferenceMode mode) {
int start, stop;
TF_RETURN_IF_ERROR(op_kernel().InputRange(name, &start, &stop));
outputs->resize(stop - start);
for (int i = start; i < stop; ++i) {
- TF_RETURN_IF_ERROR(ConstantInput(i, &(*outputs)[i]));
+ TF_RETURN_IF_ERROR(ConstantInput(i, &(*outputs)[i], mode));
}
return Status::OK();
}
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index 48c66b9..4cfe730 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -20,6 +20,7 @@
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/compiler/tf2xla/xla_resource.h"
+#include "tensorflow/compiler/xla/client/value_inference.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -130,34 +131,58 @@
// Evaluates input `index` and stores it in `*constant_literal`. If the
// expression cannot be evaluated, e.g., because it depends on unbound
- // parameters, returns a non-OK status.
- Status ConstantInput(int index, xla::Literal* constant_literal);
- Status ConstantInput(absl::string_view name, xla::Literal* constant_literal);
+  // parameters, returns a non-OK status. This function can also be used to
+  // infer a constant input's upper or lower bound by setting the `mode`
+  // parameter.
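+  //
+  // Example (sketch) requesting an inclusive upper bound for input 0:
+  //   xla::Literal bound;
+  //   OP_REQUIRES_OK(ctx, ctx->ConstantInput(
+  //                           0, &bound, xla::ValueInferenceMode::kUpperBound));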
+ Status ConstantInput(
+ int index, xla::Literal* constant_literal,
+ xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
+ Status ConstantInput(
+ absl::string_view name, xla::Literal* constant_literal,
+ xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
// Converts a constant scalar int32 or int64 tensor into an int64.
- Status ConstantInputAsIntScalar(int index, int64* out);
- Status ConstantInputAsIntScalar(absl::string_view name, int64* out);
+ Status ConstantInputAsIntScalar(
+ int index, int64* out,
+ xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
+ Status ConstantInputAsIntScalar(
+ absl::string_view name, int64* out,
+ xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
// Converts a constant scalar float32 or float64 tensor into a float64.
- Status ConstantInputAsFloatScalar(int index, double* out);
+ Status ConstantInputAsFloatScalar(
+ int index, double* out,
+ xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
// Converts a constant 1D int32 or int64 tensor into a vector of int64s.
- Status ConstantInputAsIntVector(int index, std::vector<int64>* out);
- Status ConstantInputAsIntVector(absl::string_view name,
- std::vector<int64>* out);
+ Status ConstantInputAsIntVector(
+ int index, std::vector<int64>* out,
+ xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
+ Status ConstantInputAsIntVector(
+ absl::string_view name, std::vector<int64>* out,
+ xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
// Reshapes and converts a constant int32 or int64 tensor into a vector of
// int64s.
- Status ConstantInputReshapedToIntVector(int index, std::vector<int64>* out);
- Status ConstantInputReshapedToIntVector(absl::string_view name,
- std::vector<int64>* out);
+ Status ConstantInputReshapedToIntVector(
+ int index, std::vector<int64>* out,
+ xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
+ Status ConstantInputReshapedToIntVector(
+ absl::string_view name, std::vector<int64>* out,
+ xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
// Converts a constant int32 or int64 Tensor into an xla int64 Literal.
- Status ConstantInputAsInt64Literal(int index, xla::Literal* out);
- Status ConstantInputAsInt64Literal(absl::string_view name, xla::Literal* out);
+ Status ConstantInputAsInt64Literal(
+ int index, xla::Literal* out,
+ xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
+ Status ConstantInputAsInt64Literal(
+ absl::string_view name, xla::Literal* out,
+ xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
// Converts a constant 1D int32 or int64 tensor into a TensorShape.
- Status ConstantInputAsShape(int index, TensorShape* shape);
+ Status ConstantInputAsShape(
+ int index, TensorShape* shape,
+ xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
// Converts a constant 1D int32 or int64 tensor, or a scalar with value -1
// into a PartialTensorShape.
@@ -166,8 +191,9 @@
// Returns the named list-valued immutable input in "list", as
// defined in the OpDef. If the named output is not list-valued,
// returns a one-element list.
- Status ConstantInputList(absl::string_view name,
- std::vector<xla::Literal>* literals);
+ Status ConstantInputList(
+ absl::string_view name, std::vector<xla::Literal>* outputs,
+ xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
// Returns an XlaExpression describing the value of 'index'.
const XlaExpression& InputExpression(int index);
@@ -309,8 +335,10 @@
// cannot be evaluated, e.g., because it depends on unbound parameters,
// returns a non-Ok status. If InputShape(index).num_elements() !=
// new_shape.num_elements(), returns an error status.
- Status ConstantInputReshaped(int index, absl::Span<const int64> new_dims,
- xla::Literal* constant_literal);
+ Status ConstantInputReshaped(
+ int index, absl::Span<const int64> new_dims,
+ xla::Literal* constant_literal,
+ xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
OpKernelContext* const context_;
bool dynamic_dimension_is_minus_one_;
diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD
index 0d6e283..eae1816 100644
--- a/tensorflow/compiler/xla/client/BUILD
+++ b/tensorflow/compiler/xla/client/BUILD
@@ -218,13 +218,17 @@
visibility = ["//visibility:public"],
deps = [
":xla_builder",
+ "//tensorflow/compiler/xla:comparison_util",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_evaluator",
+ "//tensorflow/compiler/xla/service:hlo_proto_cc",
+ "//tensorflow/stream_executor/lib",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
diff --git a/tensorflow/compiler/xla/client/value_inference.cc b/tensorflow/compiler/xla/client/value_inference.cc
index 41c2c92..7dd2395 100644
--- a/tensorflow/compiler/xla/client/value_inference.cc
+++ b/tensorflow/compiler/xla/client/value_inference.cc
@@ -13,15 +13,24 @@
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/client/value_inference.h"
+#include <utility>
+#include <vector>
+#include "absl/container/flat_hash_map.h"
#include "absl/types/span.h"
+#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
namespace xla {
namespace {
@@ -67,15 +76,13 @@
}
}
-using GetOperand = std::function<StatusOr<LiteralSlice>(int64 operand_index,
- int64 opreand_handle)>;
+
// HloProtoEvaluator evaluates an hlo proto and returns a literal. The user has
// to provide operand as literals through the get_operand function.
struct HloProtoEvaluator {
- explicit HloProtoEvaluator(HloInstructionProto inst, GetOperand get_operand)
+ explicit HloProtoEvaluator(HloInstructionProto inst)
: inst(std::move(inst)),
- get_operand(get_operand),
module("EmptyModuleForEvaluation", HloModuleConfig()) {}
// WithOpCode changes the called computation of the instruction being
@@ -103,6 +110,12 @@
return *this;
}
+  // WithOperands sets the operands of the instruction being evaluated.
+ HloProtoEvaluator& WithOperands(absl::Span<Literal> operands) {
+ this->operands = operands;
+ return *this;
+ }
+
StatusOr<Literal> Evaluate() {
// Evaluate the instruction by swapping it's operands with constant
// instructions with given literals.
@@ -110,9 +123,8 @@
absl::flat_hash_map<int64, HloInstruction*> operand_map;
for (int64 i = 0; i < inst.operand_ids_size(); ++i) {
int64 operand_handle = inst.operand_ids(i);
- TF_ASSIGN_OR_RETURN(auto literal, get_operand(i, inst.operand_ids(i)));
std::unique_ptr<HloInstruction> operand =
- HloInstruction::CreateConstant(literal.Clone());
+ HloInstruction::CreateConstant(operands[i].Clone());
operand_map[operand_handle] = operand.get();
builder.AddInstruction(std::move(operand));
}
@@ -144,108 +156,448 @@
}
HloInstructionProto inst;
- GetOperand get_operand;
+
HloModule module;
+ absl::Span<Literal> operands;
HloComputation* computation = nullptr;
absl::optional<PrimitiveType> primitive_type = absl::nullopt;
absl::optional<HloOpcode> opcode = absl::nullopt;
};
+
+enum PostorderDFSNodeType {
+ // This node is about figuring out the constant value.
+ kConstantValue = 0,
+ // This node is about figuring out the constant bound.
+ kConstantUpperBound,
+ kConstantLowerBound,
+ // This node is about figuring out whether a value is dynamic.
+ kValueIsDynamic,
+  // This node is about figuring out whether a bound value is dynamic. It's
+  // similar to kValueIsDynamic, but treats shape bounds as static values.
+ kBoundIsDynamic,
+};
+
+// Each node in the postorder traversal tree may depend on traversing the
+// values of the node's children.
+struct PostorderDFSDep {
+ explicit PostorderDFSDep(int64 handle, PostorderDFSNodeType type)
+ : handle(handle), type(type) {}
+ int64 handle;
+ PostorderDFSNodeType type;
+};
+
+// This function represents the logic to visit a node once its dependencies
+// (operands) are all resolved.
+using Visit = std::function<StatusOr<Literal>(absl::Span<Literal>)>;
+// Convenient specializations of the Visit function for different operand
+// counts.
+using Visit0D = std::function<StatusOr<Literal>()>;
+using Visit1D = std::function<StatusOr<Literal>(Literal)>;
+using Visit2D = std::function<StatusOr<Literal>(Literal, Literal)>;
+
+// A postorder DFS node can be visited once its dependency requests are all
+// fulfilled.
+struct PostorderDFSNode {
+ PostorderDFSNode& AddDependency(int64 handle, PostorderDFSNodeType type) {
+ dependencies.emplace_back(handle, type);
+ return *this;
+ }
+
+ PostorderDFSNode& AddVisit(const Visit& visit) {
+ this->visit = visit;
+ return *this;
+ }
+
+ PostorderDFSNode& AddVisit(const Visit0D& visit) {
+ this->visit = [visit](absl::Span<Literal> literals) { return visit(); };
+ return *this;
+ }
+
+ PostorderDFSNode& AddVisit(const Visit1D& visit) {
+ this->visit = [visit](absl::Span<Literal> literals) {
+ return visit(std::move(literals[0]));
+ };
+ return *this;
+ }
+
+ PostorderDFSNode& AddVisit(const Visit2D& visit) {
+ this->visit = [visit](absl::Span<Literal> literals) {
+ return visit(std::move(literals[0]), std::move(literals[1]));
+ };
+ return *this;
+ }
+
+ std::vector<PostorderDFSDep> dependencies;
+ Visit visit;
+};
+
+// Converts an integer handle to an HloInstructionProto.
+using HandleToInstruction = std::function<const HloInstructionProto*(int64)>;
+using HandleToComputation = std::function<const HloComputationProto*(int64)>;
+
+struct PostorderDFSVisitor {
+ PostorderDFSVisitor(HandleToInstruction handle_to_instruction,
+ HandleToComputation handle_to_computation)
+ : handle_to_instruction(handle_to_instruction),
+ handle_to_computation(handle_to_computation) {}
+
+ StatusOr<PostorderDFSNode> AnalyzeUpperBound(int64 handle);
+ StatusOr<PostorderDFSNode> AnalyzeLowerBound(int64 handle);
+ StatusOr<PostorderDFSNode> AnalyzeIsDynamic(int64 handle,
+ PostorderDFSNodeType type);
+ StatusOr<PostorderDFSNode> AnalyzeConstant(int64 handle);
+ StatusOr<PostorderDFSNode> AnalyzeConstantValueFallback(int64 handle,
+ PostorderDFSNodeType type);
+
+ StatusOr<Literal> PostOrderDFSVisit(int64 handle, PostorderDFSNodeType type);
+
+  // Returns true if the value represented by `handle` is an integral type or
+  // was just converted from an integral type to a floating point type.
+ bool IsValueEffectiveInteger(int64 handle) {
+ const HloInstructionProto* instr = handle_to_instruction(handle);
+ if (primitive_util::IsIntegralType(instr->shape().element_type())) {
+ return true;
+ }
+ // Also returns true if this is a convert that converts an integer to float.
+ HloOpcode opcode = StringToHloOpcode(instr->opcode()).ValueOrDie();
+ if (opcode != HloOpcode::kConvert) {
+ return false;
+ }
+ const HloInstructionProto* parent =
+ handle_to_instruction(instr->operand_ids(0));
+ if (primitive_util::IsIntegralType(parent->shape().element_type())) {
+ return true;
+ }
+ return false;
+ }
+
+ absl::flat_hash_map<std::pair<int64, PostorderDFSNodeType>, Literal>
+ evaluated;
+ HandleToInstruction handle_to_instruction;
+ HandleToComputation handle_to_computation;
+};
+
} // namespace
-StatusOr<Literal> ValueInference::AnalyzeConstantLiteral(int64 handle) {
- TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
- builder_->LookUpInstructionByHandle(handle));
+// Analyze a tensor's constant value, upper-bound value or lower-bound value.
+StatusOr<PostorderDFSNode> PostorderDFSVisitor::AnalyzeConstantValueFallback(
+ int64 handle, PostorderDFSNodeType type) {
+ const HloInstructionProto* root = handle_to_instruction(handle);
TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
+ PostorderDFSNode result;
+ for (auto operand_id : root->operand_ids()) {
+ result.AddDependency(operand_id, type);
+ }
switch (opcode) {
- case HloOpcode::kGetDimensionSize: {
- int64 dimension = root->dimensions(0);
- int64 operand_handle = root->operand_ids(0);
- TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
- builder_->LookUpInstructionByHandle(operand_handle));
- if (operand_proto->shape().is_dynamic_dimension(dimension)) {
- // The value is dynamic. We return a garbage literal here, which
- // will be masked out later.
- return CreateGarbageLiteral(Shape(root->shape()));
- } else {
- return LiteralUtil::CreateR0<int32>(
- operand_proto->shape().dimensions(dimension));
- }
- }
// Non functional ops.
case HloOpcode::kRng:
case HloOpcode::kAllReduce:
- // TODO(b/33009255): Implement constant folding for cross replica sum.
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kCall:
- // TODO(b/32495713): We aren't checking the to_apply computation itself,
- // so we conservatively say that computations containing the Call op
- // cannot be constant. We cannot set is_functional=false in other similar
- // cases since we're already relying on IsConstant to return true.
case HloOpcode::kCustomCall:
case HloOpcode::kWhile:
case HloOpcode::kConditional:
- // TODO(b/32495713): We aren't checking the condition and body
- // computations themselves.
case HloOpcode::kSend:
case HloOpcode::kRecv:
case HloOpcode::kParameter: {
- // The value is dynamic. We return a garbage literal here, which
- // will be masked out later.
- return CreateGarbageLiteral(Shape(root->shape()));
+ return result.AddVisit([root](absl::Span<Literal>) {
+ // The value is dynamic. We return a garbage literal here, which
+ // will be masked out later.
+ return CreateGarbageLiteral(Shape(root->shape()));
+ });
+ }
+    // These ops need special handling that the generic fallback cannot
+    // provide (e.g., the upper bound of Subtract and Divide uses the lower
+    // bound of their second operand), so the fallback rejects them.
+ case HloOpcode::kSubtract:
+ case HloOpcode::kCos:
+ case HloOpcode::kSin:
+ case HloOpcode::kNegate:
+ case HloOpcode::kAbs:
+ case HloOpcode::kDivide:
+ case HloOpcode::kGetDimensionSize: {
+ return InvalidArgument("AnalyzeConstantValue can't handle opcode: %s",
+ root->opcode());
}
case HloOpcode::kGetTupleElement: {
int64 operand_handle = root->operand_ids(0);
- TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
- builder_->LookUpInstructionByHandle(operand_handle));
+ const HloInstructionProto* operand_proto =
+ handle_to_instruction(operand_handle);
TF_ASSIGN_OR_RETURN(HloOpcode operand_opcode,
StringToHloOpcode(operand_proto->opcode()));
if (operand_opcode == HloOpcode::kParameter) {
- // Don't materialize the whole parameter if it's followed by a GTE.
- return CreateGarbageLiteral(Shape(root->shape()));
+ return PostorderDFSNode().AddVisit([root](absl::Span<Literal>) {
+ // The value is dynamic. We return a garbage literal here, which
+ // will be masked out later.
+ return CreateGarbageLiteral(Shape(root->shape()));
+ });
}
- return HloProtoEvaluator(*root,
- [&](int64 operand_index, int64 operand_handle) {
- return AnalyzeConstant(operand_handle);
- })
- .WithPrimitiveType(PRED)
- .Evaluate();
+
+ return result.AddVisit([root](absl::Span<Literal> operands) {
+ return HloProtoEvaluator(*root)
+ .WithOperands(operands)
+ .Evaluate();
+ });
}
case HloOpcode::kReduce:
case HloOpcode::kScatter:
case HloOpcode::kReduceWindow: {
- HloComputationProto computation_proto =
- builder_->embedded_[root->called_computation_ids(0)];
- TF_ASSIGN_OR_RETURN(auto computation, HloComputation::CreateFromProto(
- computation_proto, {}));
- return HloProtoEvaluator(*root,
- [&](int64 operand_index, int64 operand_handle) {
- return AnalyzeConstant(operand_handle);
- })
- .WithComputation(std::move(computation))
- .Evaluate();
+ const HloComputationProto* computation_proto =
+ handle_to_computation(root->called_computation_ids(0));
+ return result.AddVisit(
+ [root, computation_proto](
+ absl::Span<Literal> operands) -> StatusOr<Literal> {
+ TF_ASSIGN_OR_RETURN(
+ auto computation,
+ HloComputation::CreateFromProto(*computation_proto, {}));
+ return HloProtoEvaluator(*root)
+ .WithOperands(operands)
+ .WithComputation(std::move(computation))
+ .Evaluate();
+ });
}
- default:
- return HloProtoEvaluator(*root,
- [&](int64 operand_index, int64 operand_handle) {
- return AnalyzeConstant(operand_handle);
- })
- .Evaluate();
+ default: {
+ return result.AddVisit([root](absl::Span<Literal> operands) {
+ return HloProtoEvaluator(*root).WithOperands(operands).Evaluate();
+ });
+ }
}
}
-StatusOr<Literal> ValueInference::AnalyzeIsDynamicLiteral(int64 handle) {
- TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
- builder_->LookUpInstructionByHandle(handle));
+StatusOr<PostorderDFSNode> PostorderDFSVisitor::AnalyzeUpperBound(
+ int64 handle) {
+ const HloInstructionProto* root = handle_to_instruction(handle);
TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
switch (opcode) {
case HloOpcode::kGetDimensionSize: {
int64 dimension = root->dimensions(0);
int64 operand_handle = root->operand_ids(0);
- TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
- builder_->LookUpInstructionByHandle(operand_handle));
- return LiteralUtil::CreateR0<bool>(
- operand_proto->shape().is_dynamic_dimension(dimension));
+ const HloInstructionProto* operand_proto =
+ handle_to_instruction(operand_handle);
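+      // Whether or not the dimension is dynamic, its size can never exceed
+      // the static size recorded in the shape, so that is the upper bound.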
+ return PostorderDFSNode().AddVisit(
+ [operand_proto, dimension]() -> StatusOr<Literal> {
+ return LiteralUtil::CreateR0<int32>(
+ operand_proto->shape().dimensions(dimension));
+ });
+ }
+ case HloOpcode::kAbs: {
+ // upper-bound(abs(operand)) = max(abs(lower-bound(operand)),
+ // abs(upper-bound(operand)))
+ return PostorderDFSNode()
+ .AddDependency(root->operand_ids(0),
+ PostorderDFSNodeType::kConstantLowerBound)
+ .AddDependency(root->operand_ids(0),
+ PostorderDFSNodeType::kConstantUpperBound)
+ .AddVisit([](Literal lower_bound,
+ Literal upper_bound) -> StatusOr<Literal> {
+ HloEvaluator evaluator;
+ TF_ASSIGN_OR_RETURN(auto lower_bound_abs,
+ evaluator.EvaluateElementwiseUnaryOp(
+ HloOpcode::kAbs, lower_bound));
+ TF_ASSIGN_OR_RETURN(auto upper_bound_abs,
+ evaluator.EvaluateElementwiseUnaryOp(
+ HloOpcode::kAbs, upper_bound));
+ return evaluator.EvaluateElementwiseBinaryOp(
+ HloOpcode::kMaximum, lower_bound_abs, upper_bound_abs);
+ });
+ }
+ case HloOpcode::kNegate: {
+ // upper-bound(negate(operand)) = negate(lower-bound(operand))
+ return PostorderDFSNode()
+ .AddDependency(root->operand_ids(0),
+ PostorderDFSNodeType::kConstantLowerBound)
+ .AddVisit([](Literal lower_bound) -> StatusOr<Literal> {
+ HloEvaluator evaluator;
+ return evaluator.EvaluateElementwiseUnaryOp(HloOpcode::kNegate,
+ lower_bound);
+ });
+ }
+ case HloOpcode::kSubtract:
+ case HloOpcode::kDivide: {
+      // Subtract and divide are decreasing in their second operand, so its
+      // lower bound is used: upper-bound(a - b) = upper-bound(a) -
+      // lower-bound(b).
+ return PostorderDFSNode()
+ .AddDependency(root->operand_ids(0),
+ PostorderDFSNodeType::kConstantUpperBound)
+ .AddDependency(root->operand_ids(1),
+ PostorderDFSNodeType::kConstantLowerBound)
+ .AddVisit(
+ [root, opcode, this](Literal upper_bound,
+ Literal lower_bound) -> StatusOr<Literal> {
+ if (opcode == HloOpcode::kDivide &&
+ this->IsValueEffectiveInteger(root->operand_ids(1))) {
+                  // Because in many cases the lower bound of a value is
+                  // integer 0, instead of raising a divide-by-zero error
+                  // at compile time we defer the check to runtime: a zero
+                  // lower bound of the divisor is replaced with 1 as a
+                  // placeholder.
+ HloEvaluator evaluator;
+ auto zero =
+ LiteralUtil::Zero(lower_bound.shape().element_type());
+ zero = zero.Broadcast(lower_bound.shape(), {}).ValueOrDie();
+ TF_ASSIGN_OR_RETURN(
+ auto lower_bound_is_zero,
+ evaluator.EvaluateElementwiseCompareOp(
+ ComparisonDirection::kEq, lower_bound, zero));
+
+ auto one =
+ LiteralUtil::One(lower_bound.shape().element_type());
+ one = one.Broadcast(lower_bound.shape(), {}).ValueOrDie();
+ TF_ASSIGN_OR_RETURN(
+ lower_bound, evaluator.EvaluateElementwiseTernaryOp(
+ HloOpcode::kSelect, lower_bound_is_zero,
+ one, lower_bound));
+ }
+ std::vector<Literal> new_operands;
+ new_operands.emplace_back(std::move(upper_bound));
+ new_operands.emplace_back(std::move(lower_bound));
+ return HloProtoEvaluator(*root)
+ .WithOperands(absl::MakeSpan(new_operands))
+ .Evaluate();
+ });
+ }
+ default:
+ return AnalyzeConstantValueFallback(
+ handle, PostorderDFSNodeType::kConstantUpperBound);
+ }
+}
+
+StatusOr<PostorderDFSNode> PostorderDFSVisitor::AnalyzeLowerBound(
+ int64 handle) {
+ const HloInstructionProto* root = handle_to_instruction(handle);
+ TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
+ switch (opcode) {
+ case HloOpcode::kGetDimensionSize: {
+ int64 dimension = root->dimensions(0);
+ int64 operand_handle = root->operand_ids(0);
+ const HloInstructionProto* operand_proto =
+ handle_to_instruction(operand_handle);
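+      // A dynamic dimension can be as small as 0, so its lower bound is 0; a
+      // static dimension is exactly its recorded size.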
+ return PostorderDFSNode().AddVisit(
+ [dimension, operand_proto]() -> StatusOr<Literal> {
+ if (operand_proto->shape().is_dynamic_dimension(dimension)) {
+ return LiteralUtil::CreateR0<int32>(0);
+ } else {
+ return LiteralUtil::CreateR0<int32>(
+ operand_proto->shape().dimensions(dimension));
+ }
+ });
+ }
+ case HloOpcode::kAbs: {
+ // lower-bound(abs(operand)) = min(abs(lower-bound(operand)),
+ // abs(upper-bound(operand)))
+ return PostorderDFSNode()
+ .AddDependency(root->operand_ids(0),
+ PostorderDFSNodeType::kConstantLowerBound)
+ .AddDependency(root->operand_ids(0),
+ PostorderDFSNodeType::kConstantUpperBound)
+ .AddVisit([](Literal lower_bound,
+ Literal upper_bound) -> StatusOr<Literal> {
+ HloEvaluator evaluator;
+ TF_ASSIGN_OR_RETURN(auto lower_bound_abs,
+ evaluator.EvaluateElementwiseUnaryOp(
+ HloOpcode::kAbs, lower_bound));
+ TF_ASSIGN_OR_RETURN(auto upper_bound_abs,
+ evaluator.EvaluateElementwiseUnaryOp(
+ HloOpcode::kAbs, upper_bound));
+ return evaluator.EvaluateElementwiseBinaryOp(
+ HloOpcode::kMinimum, lower_bound_abs, upper_bound_abs);
+ });
+ }
+ case HloOpcode::kNegate: {
+ // lower-bound(negate(operand)) = negate(upper-bound(operand))
+ return PostorderDFSNode()
+ .AddDependency(root->operand_ids(0),
+ PostorderDFSNodeType::kConstantUpperBound)
+ .AddVisit([](Literal upper_bound) -> StatusOr<Literal> {
+ HloEvaluator evaluator;
+ return evaluator.EvaluateElementwiseUnaryOp(HloOpcode::kNegate,
+ upper_bound);
+ });
+ }
+ case HloOpcode::kSubtract:
+ case HloOpcode::kDivide: {
+      // Symmetrically, the upper bound of the second operand is used:
+      // lower-bound(a - b) = lower-bound(a) - upper-bound(b).
+ return PostorderDFSNode()
+ .AddDependency(root->operand_ids(0),
+ PostorderDFSNodeType::kConstantLowerBound)
+ .AddDependency(root->operand_ids(1),
+ PostorderDFSNodeType::kConstantUpperBound)
+ .AddVisit([root](absl::Span<Literal> operands) -> StatusOr<Literal> {
+ return HloProtoEvaluator(*root).WithOperands(operands).Evaluate();
+ });
+ }
+ default:
+ return AnalyzeConstantValueFallback(
+ handle, PostorderDFSNodeType::kConstantLowerBound);
+ }
+}
+
+StatusOr<PostorderDFSNode> PostorderDFSVisitor::AnalyzeConstant(
+ int64 handle) {
+ const HloInstructionProto* root = handle_to_instruction(handle);
+ HloOpcode opcode = StringToHloOpcode(root->opcode()).ValueOrDie();
+ switch (opcode) {
+ case HloOpcode::kGetDimensionSize: {
+ int64 dimension = root->dimensions(0);
+ int64 operand_handle = root->operand_ids(0);
+ const HloInstructionProto* operand_proto =
+ handle_to_instruction(operand_handle);
+ return PostorderDFSNode().AddVisit(
+ [operand_proto, dimension, root]() -> StatusOr<Literal> {
+ if (operand_proto->shape().is_dynamic_dimension(dimension)) {
+            // The value is dynamic; we return garbage data here and mask it
+            // out later.
+ return CreateGarbageLiteral(Shape(root->shape()));
+ } else {
+ return LiteralUtil::CreateR0<int32>(
+ operand_proto->shape().dimensions(dimension));
+ }
+ });
+ }
+ case HloOpcode::kSubtract:
+ case HloOpcode::kCos:
+ case HloOpcode::kSin:
+ case HloOpcode::kNegate:
+ case HloOpcode::kAbs:
+ case HloOpcode::kDivide: {
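+      // For exact values these ops evaluate directly on their operands'
+      // constant values; they are special-cased only because the generic
+      // fallback rejects them.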
+ PostorderDFSNode result;
+ for (auto operand_id : root->operand_ids()) {
+ result.AddDependency(operand_id, PostorderDFSNodeType::kConstantValue);
+ }
+ return result.AddVisit(
+ [root](absl::Span<Literal> operands) -> StatusOr<Literal> {
+ return HloProtoEvaluator(*root).WithOperands(operands).Evaluate();
+ });
+ }
+ default:
+ return AnalyzeConstantValueFallback(handle,
+ PostorderDFSNodeType::kConstantValue);
+ }
+}
+
+StatusOr<PostorderDFSNode> PostorderDFSVisitor::AnalyzeIsDynamic(
+ int64 handle, PostorderDFSNodeType type) {
+ const HloInstructionProto* root = handle_to_instruction(handle);
+ // Invariant check.
+ TF_RET_CHECK(root);
+ TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
+ PostorderDFSNode result;
+ for (auto operand_id : root->operand_ids()) {
+ result.AddDependency(operand_id, type);
+ }
+ switch (opcode) {
+ case HloOpcode::kGetDimensionSize: {
+ int64 dimension = root->dimensions(0);
+ int64 operand_handle = root->operand_ids(0);
+ const HloInstructionProto* operand_proto =
+ handle_to_instruction(operand_handle);
+ return PostorderDFSNode().AddVisit([operand_proto, dimension,
+ type]() -> StatusOr<Literal> {
+ if (type == PostorderDFSNodeType::kBoundIsDynamic) {
+          // The bound of a dynamic dimension is not itself dynamic.
+ return LiteralUtil::CreateR0<bool>(false);
+ }
+        // The value of a dynamic dimension is dynamic.
+ return LiteralUtil::CreateR0<bool>(
+ operand_proto->shape().is_dynamic_dimension(dimension));
+ });
}
case HloOpcode::kAbs:
case HloOpcode::kRoundNearestAfz:
@@ -274,9 +626,7 @@
case HloOpcode::kCbrt:
case HloOpcode::kTanh: {
// Forward operand as they don't change if a value is dynamic or static.
- int64 operand_handle = root->operand_ids(0);
- TF_ASSIGN_OR_RETURN(auto literal, AnalyzeIsDynamic(operand_handle));
- return literal.Clone();
+ return result.AddVisit([](Literal operand) { return operand; });
}
case HloOpcode::kAdd:
case HloOpcode::kAtan2:
@@ -295,13 +645,13 @@
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical: {
- return HloProtoEvaluator(*root,
- [&](int64 operand_index, int64 operand_handle) {
- return AnalyzeIsDynamic(operand_handle);
- })
- .WithPrimitiveType(PRED)
- .WithOpCode(HloOpcode::kOr)
- .Evaluate();
+ return result.AddVisit([root](absl::Span<Literal> operands) {
+ return HloProtoEvaluator(*root)
+ .WithOperands(operands)
+ .WithPrimitiveType(PRED)
+ .WithOpCode(HloOpcode::kOr)
+ .Evaluate();
+ });
}
case HloOpcode::kTuple:
case HloOpcode::kTranspose:
@@ -310,112 +660,135 @@
case HloOpcode::kConcatenate:
case HloOpcode::kReshape:
case HloOpcode::kPad: {
- return HloProtoEvaluator(*root,
- [&](int64 operand_index, int64 operand_handle) {
- return AnalyzeIsDynamic(operand_handle);
- })
- .WithPrimitiveType(PRED)
- .Evaluate();
+ return result.AddVisit([root](absl::Span<Literal> operands) {
+ return HloProtoEvaluator(*root)
+ .WithOperands(operands)
+ .WithPrimitiveType(PRED)
+ .Evaluate();
+ });
}
case HloOpcode::kGetTupleElement: {
int64 operand_handle = root->operand_ids(0);
- TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
- builder_->LookUpInstructionByHandle(operand_handle));
+ const HloInstructionProto* operand_proto =
+ handle_to_instruction(operand_handle);
TF_ASSIGN_OR_RETURN(HloOpcode operand_opcode,
StringToHloOpcode(operand_proto->opcode()));
if (operand_opcode == HloOpcode::kParameter) {
- // Don't materialize the whole parameter if it's followed by a GTE.
- return CreatePredLiteral(true, Shape(root->shape()));
+        return PostorderDFSNode().AddVisit([root]() -> StatusOr<Literal> {
+ // Don't materialize the whole parameter if it's followed by a GTE.
+ return CreatePredLiteral(true, Shape(root->shape()));
+ });
}
- return HloProtoEvaluator(*root,
- [&](int64 operand_index, int64 operand_handle) {
- return AnalyzeIsDynamic(operand_handle);
- })
- .WithPrimitiveType(PRED)
- .Evaluate();
+ return result.AddVisit([root](absl::Span<Literal> operands) {
+ return HloProtoEvaluator(*root)
+ .WithOperands(operands)
+ .WithPrimitiveType(PRED)
+ .Evaluate();
+ });
}
case HloOpcode::kReduce: {
- std::vector<std::unique_ptr<HloInstruction>> operand_storage;
- absl::flat_hash_map<int64, HloInstruction*> operand_map;
- absl::flat_hash_map<int64, HloComputation*> computation_map;
-
- Shape scalar_shape = ShapeUtil::MakeScalarShape(xla::PRED);
- HloComputation::Builder b("reduce_or");
- auto lhs = b.AddInstruction(
- HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
- auto rhs = b.AddInstruction(
- HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
- b.AddInstruction(
- HloInstruction::CreateBinary(scalar_shape, HloOpcode::kOr, lhs, rhs));
- auto reduce_computation = b.Build();
- return HloProtoEvaluator(*root,
- [&](int64 operand_index, int64 operand_handle) {
- return AnalyzeIsDynamic(operand_handle);
- })
- .WithPrimitiveType(PRED)
- .WithComputation(std::move(reduce_computation))
- .Evaluate();
+ return result.AddVisit([root](absl::Span<Literal> operands) {
+ Shape scalar_shape = ShapeUtil::MakeScalarShape(xla::PRED);
+ HloComputation::Builder b("reduce_or");
+ auto lhs = b.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
+ auto rhs = b.AddInstruction(
+ HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
+ b.AddInstruction(HloInstruction::CreateBinary(
+ scalar_shape, HloOpcode::kOr, lhs, rhs));
+ auto reduce_computation = b.Build();
+ return HloProtoEvaluator(*root)
+ .WithOperands(operands)
+ .WithPrimitiveType(PRED)
+ .WithComputation(std::move(reduce_computation))
+ .Evaluate();
+ });
}
case HloOpcode::kConstant:
case HloOpcode::kIota: {
- return CreatePredLiteral(false, Shape(root->shape()));
+ return result.AddVisit(
+ [root]() { return CreatePredLiteral(false, Shape(root->shape())); });
}
case HloOpcode::kParameter: {
- return CreatePredLiteral(true, Shape(root->shape()));
+ return result.AddVisit(
+ [root]() { return CreatePredLiteral(true, Shape(root->shape())); });
}
case HloOpcode::kSelect: {
- TF_ASSIGN_OR_RETURN(OptionaLiteralSlice optional_selector_literal,
- AnalyzeOptionalConstant(root->operand_ids(0)));
- TF_ASSIGN_OR_RETURN(LiteralSlice lhs,
- AnalyzeIsDynamic(root->operand_ids(1)));
- TF_ASSIGN_OR_RETURN(LiteralSlice rhs,
- AnalyzeIsDynamic(root->operand_ids(2)));
+ return PostorderDFSNode()
+ .AddDependency(root->operand_ids(0),
+ PostorderDFSNodeType::kConstantValue)
+ .AddDependency(root->operand_ids(0),
+ PostorderDFSNodeType::kValueIsDynamic)
+ // lhs dependency.
+ .AddDependency(root->operand_ids(1), type)
+ // rhs dependency.
+ .AddDependency(root->operand_ids(2), type)
+ .AddVisit([root](absl::Span<Literal> operands) -> StatusOr<Literal> {
+ OptionalLiteral optional_selector_literal(std::move(operands[0]),
+ std::move(operands[1]));
+ Literal lhs = std::move(operands[2]);
+ Literal rhs = std::move(operands[3]);
+ auto result = CreatePredLiteral(true, Shape(root->shape()));
- auto result = CreatePredLiteral(true, Shape(root->shape()));
+ result.MutableEachCell<bool>(
+ [&](absl::Span<const int64> indices, bool value) {
+ absl::optional<bool> optional_selector =
+ optional_selector_literal.Get<bool>(indices);
- result.MutableEachCell<bool>(
- [&](absl::Span<const int64> indices, bool value) {
- absl::optional<bool> optional_selector =
- optional_selector_literal.Get<bool>(indices);
-
- bool lhs_value = lhs.Get<bool>(indices);
- bool rhs_value = rhs.Get<bool>(indices);
- if (optional_selector.has_value()) {
- // Manually evaluate the selection without using Evaluator.
- if (*optional_selector) {
- return lhs_value;
- } else {
- return rhs_value;
- }
- } else {
- // Conservatively assume value is dynamic if selector is dynamic.
- return true;
- }
+ bool lhs_value = lhs.Get<bool>(indices);
+ bool rhs_value = rhs.Get<bool>(indices);
+ if (optional_selector.has_value()) {
+ // Manually evaluate the selection without using Evaluator.
+ if (*optional_selector) {
+ return lhs_value;
+ } else {
+ return rhs_value;
+ }
+ } else {
+ // Conservatively assume value is dynamic if selector is
+ // dynamic.
+ return true;
+ }
+ });
+ return result;
});
- return result;
}
case HloOpcode::kGather: {
- TF_ASSIGN_OR_RETURN(OptionaLiteralSlice optional_selector_literal,
- AnalyzeOptionalConstant(root->operand_ids(1)));
- if (!optional_selector_literal.AllValid()) {
- // Conservatively assume result are dynamic.
- return CreatePredLiteral(true, Shape(root->shape()));
- }
- return HloProtoEvaluator(*root,
- [&](int64 operand_index, int64 operand_handle) {
- if (operand_index == 1) {
- return AnalyzeConstant(operand_handle);
- } else {
- return AnalyzeIsDynamic(operand_handle);
- }
- })
- .WithPrimitiveType(PRED)
- .Evaluate();
+ return PostorderDFSNode()
+ .AddDependency(root->operand_ids(0), type)
+ .AddDependency(root->operand_ids(1),
+ PostorderDFSNodeType::kConstantValue)
+ .AddDependency(root->operand_ids(1),
+ PostorderDFSNodeType::kValueIsDynamic)
+ .AddVisit([root](absl::Span<Literal> operands) -> StatusOr<Literal> {
+ OptionalLiteral optional_selector_literal(std::move(operands[1]),
+ std::move(operands[2]));
+
+ if (!optional_selector_literal.AllValid()) {
+ // Conservatively assume results are dynamic.
+ return CreatePredLiteral(true, Shape(root->shape()));
+ }
+ std::vector<Literal> new_operands;
+ new_operands.emplace_back(std::move(operands[0]));
+ new_operands.emplace_back(
+ optional_selector_literal.GetValue()->Clone());
+
+ return HloProtoEvaluator(*root)
+ .WithOperands(absl::MakeSpan(new_operands))
+ .WithPrimitiveType(PRED)
+ .Evaluate();
+ });
}
case HloOpcode::kCustomCall: {
if (root->custom_call_target() == "SetBound") {
- return CreatePredLiteral(true, Shape(root->shape()));
+ return PostorderDFSNode().AddVisit([type, root]() -> StatusOr<Literal> {
+ if (type == PostorderDFSNodeType::kBoundIsDynamic) {
+ return CreatePredLiteral(false, Shape(root->shape()));
+ } else {
+ return CreatePredLiteral(true, Shape(root->shape()));
+ }
+ });
} else {
return InvalidArgument(
"Dynamic inferencing on custom call %s is not supported",
@@ -430,29 +803,138 @@
}
}
-StatusOr<LiteralSlice> ValueInference::AnalyzeIsDynamic(int64 handle) {
- if (is_dynamic_.contains(handle)) {
- return LiteralSlice(is_dynamic_[handle]);
+StatusOr<Literal> PostorderDFSVisitor::PostOrderDFSVisit(
+ int64 handle, PostorderDFSNodeType type) {
+ enum VisitState {
+ kUnvisited = 0,
+ kVisiting,
+ kVisited,
+ };
+
+ struct WorkItem {
+ explicit WorkItem(int64 handle, PostorderDFSNodeType type, VisitState state)
+ : handle(handle), type(type), state(state) {}
+ int64 handle;
+ PostorderDFSNodeType type;
+ VisitState state;
+ PostorderDFSNode node;
+ };
+
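+  // Iterative postorder DFS: a handle is pushed as kUnvisited; when first
+  // popped its dependencies are enqueued and it is marked kVisiting; its
+  // visit function runs once all dependencies are cached in `evaluated`.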
+ std::vector<WorkItem> stack;
+ stack.push_back(WorkItem(handle, type, kUnvisited));
+ while (!stack.empty()) {
+ WorkItem& item = stack.back();
+ VLOG(1) << "stack top" << handle_to_instruction(item.handle)->DebugString();
+ if (item.state == kVisiting) {
+ VLOG(1) << "visiting";
+ // The operands are ready, visit the node itself.
+
+ // Gather dependencies.
+ std::vector<Literal> literals;
+ for (const PostorderDFSDep& dep : item.node.dependencies) {
+ std::pair<int64, PostorderDFSNodeType> key(dep.handle, dep.type);
+ TF_RET_CHECK(evaluated.contains(key));
+ literals.emplace_back(evaluated.at(key).Clone());
+ }
+ VLOG(1) << "start visiting";
+ TF_ASSIGN_OR_RETURN(auto literal,
+ item.node.visit(absl::MakeSpan(literals)));
+ VLOG(1) << "end visiting: " << literal.ToString();
+ std::pair<int64, PostorderDFSNodeType> key(item.handle, item.type);
+ evaluated[key] = std::move(literal);
+ stack.pop_back();
+ continue;
+ }
+    // This is the first time we see this node; gather its dependencies.
+ VLOG(1) << "unvisited";
+ item.state = kVisiting;
+ PostorderDFSNode node;
+ switch (item.type) {
+ case PostorderDFSNodeType::kConstantValue: {
+ TF_ASSIGN_OR_RETURN(node, AnalyzeConstant(item.handle));
+ break;
+ }
+ case PostorderDFSNodeType::kConstantLowerBound: {
+ TF_ASSIGN_OR_RETURN(node, AnalyzeLowerBound(item.handle));
+ break;
+ }
+ case PostorderDFSNodeType::kConstantUpperBound: {
+ TF_ASSIGN_OR_RETURN(node, AnalyzeUpperBound(item.handle));
+ break;
+ }
+ case PostorderDFSNodeType::kBoundIsDynamic:
+ case PostorderDFSNodeType::kValueIsDynamic: {
+ TF_ASSIGN_OR_RETURN(node, AnalyzeIsDynamic(item.handle, item.type));
+ break;
+ }
+ }
+    // Store the node; it will be visited once its dependencies are resolved.
+ item.node = node;
+ // Enqueue dependencies into the stack.
+ for (const PostorderDFSDep& dep : node.dependencies) {
+ VLOG(1) << "dep" << handle_to_instruction(dep.handle)->DebugString();
+ stack.emplace_back(dep.handle, dep.type, kUnvisited);
+ }
}
- TF_ASSIGN_OR_RETURN(Literal literal, AnalyzeIsDynamicLiteral(handle));
- is_dynamic_[handle] = std::move(literal);
- return LiteralSlice(is_dynamic_[handle]);
+ VLOG(1) << "done" << evaluated[std::make_pair(handle, type)].ToString();
+ return evaluated[std::make_pair(handle, type)].Clone();
}
-StatusOr<LiteralSlice> ValueInference::AnalyzeConstant(int64 handle) {
- if (constant_.contains(handle)) {
- return LiteralSlice(constant_[handle]);
- }
- TF_ASSIGN_OR_RETURN(Literal literal, AnalyzeConstantLiteral(handle));
- constant_[handle] = std::move(literal);
- return LiteralSlice(constant_[handle]);
+StatusOr<Literal> ValueInference::AnalyzeIsDynamic(XlaOp op) {
+ PostorderDFSVisitor visitor(
+ [&](int64 handle) {
+ return builder_->LookUpInstructionByHandle(handle).ValueOrDie();
+ },
+ [&](int64 handle) { return &(builder_->embedded_[handle]); });
+ return visitor.PostOrderDFSVisit(op.handle(),
+ PostorderDFSNodeType::kValueIsDynamic);
}
-StatusOr<OptionaLiteralSlice> ValueInference::AnalyzeOptionalConstant(
- int64 handle) {
- TF_ASSIGN_OR_RETURN(LiteralSlice value, AnalyzeConstant(handle));
- TF_ASSIGN_OR_RETURN(LiteralSlice mask, AnalyzeIsDynamic(handle));
- return OptionaLiteralSlice(value, mask);
+StatusOr<OptionalLiteral> ValueInference::AnalyzeConstant(
+ XlaOp op, ValueInferenceMode mode) {
+ PostorderDFSVisitor visitor(
+ [&](int64 handle) {
+ return builder_->LookUpInstructionByHandle(handle).ValueOrDie();
+ },
+ [&](int64 handle) { return &(builder_->embedded_[handle]); });
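+  // Each mode pairs a value literal with a dynamism mask; the resulting
+  // OptionalLiteral yields nullopt for elements the mask marks as unknown.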
+ switch (mode) {
+ case ValueInferenceMode::kLowerBound: {
+ TF_ASSIGN_OR_RETURN(
+ Literal value,
+ visitor.PostOrderDFSVisit(op.handle(),
+ PostorderDFSNodeType::kConstantLowerBound));
+ TF_ASSIGN_OR_RETURN(
+ Literal mask,
+ visitor.PostOrderDFSVisit(op.handle(),
+ PostorderDFSNodeType::kBoundIsDynamic));
+ return OptionalLiteral(std::move(value), std::move(mask));
+ }
+
+ case ValueInferenceMode::kUpperBound: {
+ TF_ASSIGN_OR_RETURN(
+ Literal value,
+ visitor.PostOrderDFSVisit(op.handle(),
+ PostorderDFSNodeType::kConstantUpperBound));
+ TF_ASSIGN_OR_RETURN(
+ Literal mask,
+ visitor.PostOrderDFSVisit(op.handle(),
+ PostorderDFSNodeType::kBoundIsDynamic));
+
+ return OptionalLiteral(std::move(value), std::move(mask));
+ }
+ case ValueInferenceMode::kValue: {
+ TF_ASSIGN_OR_RETURN(
+ Literal value,
+ visitor.PostOrderDFSVisit(op.handle(),
+ PostorderDFSNodeType::kConstantValue));
+ TF_ASSIGN_OR_RETURN(
+ Literal mask,
+ visitor.PostOrderDFSVisit(op.handle(),
+ PostorderDFSNodeType::kValueIsDynamic));
+ return OptionalLiteral(std::move(value), std::move(mask));
+ }
+ }
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/client/value_inference.h b/tensorflow/compiler/xla/client/value_inference.h
index afe8dfc..c7eb760 100644
--- a/tensorflow/compiler/xla/client/value_inference.h
+++ b/tensorflow/compiler/xla/client/value_inference.h
@@ -28,30 +28,49 @@
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
-// OptionaLiteralSlice is an augmented literal class which returns optional
-// values for each index (the value can be either valid or invalid). Underneath
-// it keeps two literals, a value literal, holding both the valid and garabage
-// value, and a masking litearl representing if a value is valid or garbage.
-class OptionaLiteralSlice {
+// OptionalLiteral is an augmented literal class which returns optional
+// values for each index (the value can be either valid or invalid). The
+// implementation keeps two literals: a value literal, holding both the valid
+// and garbage values, and a masking literal representing whether each value is
+// valid or garbage.
+class OptionalLiteral {
public:
- explicit OptionaLiteralSlice(LiteralSlice value, LiteralSlice mask)
- : value_(value), mask_(mask) {}
+ explicit OptionalLiteral(Literal value, Literal mask)
+ : value_(std::move(value)), mask_(std::move(mask)) {}
template <typename NativeT>
- absl::optional<NativeT> Get(absl::Span<const int64> multi_index) const {
- if (mask_.Get<bool>(multi_index)) {
+ absl::optional<NativeT> Get(absl::Span<const int64> element_index,
+ ShapeIndex shape_index = {}) const {
+ if (mask_.Get<bool>(element_index, shape_index)) {
return absl::nullopt;
} else {
- return value_.Get<NativeT>(multi_index);
+ return value_.Get<NativeT>(element_index, shape_index);
}
}
// Returns true if all values in this literal slice are value.
bool AllValid() { return mask_.IsAll(0); }
+  // Returns the value of this literal if all of its elements are valid;
+  // otherwise returns nullopt.
+ absl::optional<LiteralSlice> GetValue() {
+ if (!AllValid()) {
+ return absl::nullopt;
+ }
+ return LiteralSlice(value_);
+ }
+
private:
- LiteralSlice value_;
- LiteralSlice mask_;
+ Literal value_;
+ Literal mask_;
+};
+
+enum ValueInferenceMode {
+  // Infer the constant value itself.
+  kValue = 0,
+  // Infer the upper or lower bound of the value. Bounds are inclusive.
+ kUpperBound,
+ kLowerBound,
};
class ValueInference {
@@ -61,38 +80,16 @@
// - What's the lower-bound of each value in a tensor.
// - What's the constant value of each tensor.
// - Whether or not each value in a tensor is dynamic.
- explicit ValueInference(XlaBuilder* builder) : builder_(builder) {}
- StatusOr<LiteralSlice> AnalyzeUpperBound(XlaOp op) {
- return Unimplemented("Analyzing upper-bound is not implemented yet.");
+ explicit ValueInference(XlaBuilder* builder) : builder_(builder) {
+ CHECK(builder_);
}
- StatusOr<LiteralSlice> AnalyzeLowerBound(XlaOp op) {
- return Unimplemented("Analyzing lower-bound is not implemented yet.");
- }
- StatusOr<LiteralSlice> AnalyzeIsDynamic(XlaOp op) {
- return AnalyzeIsDynamic(op.handle());
- }
-
- // Returns a OptionalConstant, the value is nullopt it's dynamic, otherwise a
- // concrete constant value.
- StatusOr<OptionaLiteralSlice> AnalyzeOptionalConstant(XlaOp op) {
- return AnalyzeOptionalConstant(op.handle());
- }
+ StatusOr<Literal> AnalyzeIsDynamic(XlaOp op);
+  // Returns an OptionalLiteral. Each element of the literal is the concrete
+  // value (or bound, depending on `mode`) if it can be inferred, otherwise
+  // nullopt.
+ StatusOr<OptionalLiteral> AnalyzeConstant(XlaOp op, ValueInferenceMode mode);
private:
- StatusOr<LiteralSlice> AnalyzeIsDynamic(int64 handle);
- StatusOr<LiteralSlice> AnalyzeConstant(int64 handle);
- StatusOr<OptionaLiteralSlice> AnalyzeOptionalConstant(int64 handle);
-
- StatusOr<Literal> AnalyzeIsDynamicLiteral(int64 handle);
- StatusOr<Literal> AnalyzeConstantLiteral(int64 handle);
-
XlaBuilder* builder_;
- // Cache to avoid re-evaluating. Mapping of xla handle to evaluated
- // literals.
- absl::flat_hash_map<int64, Literal> upper_bound_;
- absl::flat_hash_map<int64, Literal> lower_bound_;
- absl::flat_hash_map<int64, Literal> is_dynamic_;
- absl::flat_hash_map<int64, Literal> constant_;
HloEvaluator evaluator_;
};
} // namespace xla
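
Usage sketch for the new header API (illustrative only, not part of this
change; assumes code inside namespace xla with an existing XlaBuilder `b` and
an XlaOp `op` built on it):

  ValueInference value_inference(&b);
  // Ask for the inclusive upper bound of every element produced by `op`.
  TF_ASSIGN_OR_RETURN(
      OptionalLiteral bound,
      value_inference.AnalyzeConstant(op, ValueInferenceMode::kUpperBound));
  // Each element is nullopt where no bound could be inferred (e.g. a
  // parameter's value); here we read a scalar result.
  absl::optional<int32> upper = bound.Get<int32>({});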
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 43e5d41..6b95fdf 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -303,7 +303,7 @@
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul",
"//tensorflow/core:lib",
- "//third_party/eigen3",
+ "//tensorflow/stream_executor/lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:inlined_vector",
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 9816a2d..aa3150e 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -52,6 +52,7 @@
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
namespace xla {
@@ -367,6 +368,40 @@
return result;
}
+StatusOr<Literal> HloEvaluator::EvaluateElementwiseTernaryOp(
+ HloOpcode opcode, const Literal& lhs, const Literal& rhs,
+ const Literal& ehs) {
+ std::unique_ptr<HloInstruction> lhs_instr =
+ HloInstruction::CreateConstant(lhs.Clone());
+ std::unique_ptr<HloInstruction> rhs_instr =
+ HloInstruction::CreateConstant(rhs.Clone());
+ std::unique_ptr<HloInstruction> ehs_instr =
+ HloInstruction::CreateConstant(ehs.Clone());
+ TF_ASSIGN_OR_RETURN(auto output_shape,
+ ShapeInference::InferTernaryOpShape(
+ opcode, lhs.shape(), rhs.shape(), ehs.shape()));
+ std::unique_ptr<HloInstruction> cloned_instruction =
+ HloInstruction::CreateTernary(output_shape, opcode, lhs_instr.get(),
+ rhs_instr.get(), ehs_instr.get());
+ return Evaluate(cloned_instruction.get());
+}
+
+StatusOr<Literal> HloEvaluator::EvaluateElementwiseCompareOp(
+ ComparisonDirection direction, const Literal& lhs, const Literal& rhs) {
+ std::unique_ptr<HloInstruction> lhs_instr =
+ HloInstruction::CreateConstant(lhs.Clone());
+ std::unique_ptr<HloInstruction> rhs_instr =
+ HloInstruction::CreateConstant(rhs.Clone());
+
+ std::unique_ptr<HloInstruction> cloned_instruction =
+ HloInstruction::CreateCompare(
+ ShapeUtil::ChangeElementType(lhs.shape(), PRED), lhs_instr.get(),
+ rhs_instr.get(), direction);
+  return Evaluate(cloned_instruction.get());
+}
+
StatusOr<Literal> HloEvaluator::EvaluateElementwiseUnaryOp(
HloOpcode opcode, const Literal& operand) {
std::unique_ptr<HloInstruction> operand_instr =
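
Usage sketch for the new evaluator helpers (illustrative only; the input
values are made up, inside namespace xla):

  HloEvaluator evaluator;
  Literal lhs = LiteralUtil::CreateR1<int32>({1, 5});
  Literal rhs = LiteralUtil::CreateR1<int32>({2, 3});
  // Elementwise `lhs < rhs` folded without building an HloModule; the result
  // is a PRED literal holding {true, false}.
  TF_ASSIGN_OR_RETURN(Literal pred,
                      evaluator.EvaluateElementwiseCompareOp(
                          ComparisonDirection::kLt, lhs, rhs));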
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index 0d64582..a891312 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -124,6 +124,15 @@
StatusOr<Literal> EvaluateElementwiseUnaryOp(HloOpcode opcode,
const Literal& operand);
+ StatusOr<Literal> EvaluateElementwiseTernaryOp(HloOpcode opcode,
+ const Literal& lhs,
+ const Literal& rhs,
+ const Literal& ehs);
+
+ StatusOr<Literal> EvaluateElementwiseCompareOp(ComparisonDirection direction,
+ const Literal& lhs,
+ const Literal& rhs);
+
StatusOr<Literal> EvaluateDotOp(const DotDimensionNumbers& dim_numbers,
const PrecisionConfig& precision_config,
const Literal& lhs, const Literal& rhs);
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 822dec1..806d5bd 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -2141,6 +2141,7 @@
"//tensorflow/compiler/xla/client/lib:prng",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
diff --git a/tensorflow/compiler/xla/tests/dynamism_inference_test.cc b/tensorflow/compiler/xla/tests/dynamism_inference_test.cc
index 972b237..a5019fa 100644
--- a/tensorflow/compiler/xla/tests/dynamism_inference_test.cc
+++ b/tensorflow/compiler/xla/tests/dynamism_inference_test.cc
@@ -43,35 +43,17 @@
namespace xla {
namespace {
-// An enumerator for the client types that we want to iterate over in
-// the various tests.
-enum class ClientType { kLocal, kCompileOnly };
-
-class DynamismInferenceTest : public ::testing::Test {
+class ValueInferenceTest : public ::testing::Test {
public:
- explicit DynamismInferenceTest(se::Platform* platform = nullptr)
- : platform_(platform) {}
-
string TestName() const {
return ::testing::UnitTest::GetInstance()->current_test_info()->name();
}
+};
- Client* ClientOrDie(se::Platform* platform, ClientType client_type) {
- if (client_type == ClientType::kLocal) {
- StatusOr<Client*> result =
- ClientLibrary::GetOrCreateLocalClient(platform);
- TF_CHECK_OK(result.status())
- << "could not create LocalClient for testing";
- return result.ValueOrDie();
- } else if (client_type == ClientType::kCompileOnly) {
- StatusOr<Client*> result =
- ClientLibrary::GetOrCreateCompileOnlyClient(platform);
- TF_CHECK_OK(result.status())
- << "could not create CompileOnlyClient for testing";
- return result.ValueOrDie();
- }
- LOG(FATAL) << "invalid client_type value";
- }
+class DynamismInferenceTest : public ValueInferenceTest {
+ public:
+ explicit DynamismInferenceTest(se::Platform* platform = nullptr)
+ : platform_(platform) {}
StatusOr<Literal> ComputeDynamismLiteral(XlaOp operand, XlaBuilder* builder,
Layout* output_layout = nullptr) {
@@ -316,5 +298,80 @@
EXPECT_TRUE(ComputeDynamismLiteral(pad, &b).ValueOrDie().Get<bool>({2}));
}
+class UpperBoundInferenceTest : public ValueInferenceTest {
+ public:
+ explicit UpperBoundInferenceTest(se::Platform* platform = nullptr)
+ : platform_(platform) {}
+
+ StatusOr<OptionalLiteral> ComputeUpperBoundLiteral(
+ XlaOp operand, XlaBuilder* builder, Layout* output_layout = nullptr) {
+ ValueInference value_inference(builder);
+ TF_ASSIGN_OR_RETURN(auto literal,
+ value_inference.AnalyzeConstant(
+ operand, ValueInferenceMode::kUpperBound));
+ return literal;
+ }
+
+ se::Platform* platform_;
+};
+
+TEST_F(UpperBoundInferenceTest, GetDimensionSize) {
+ XlaBuilder b(TestName());
+ auto p =
+ Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}), "p0");
+
+ auto gds0 = GetDimensionSize(p, 0);
+ auto gds1 = GetDimensionSize(p, 1);
+ auto tuple_2 = Tuple(&b, {gds0, gds1});
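+  // Get({}, {i}) below reads tuple leaf i: `{}` is the element index within
+  // the scalar leaf and `{i}` is the ShapeIndex selecting the leaf.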
+ EXPECT_EQ(
+ ComputeUpperBoundLiteral(tuple_2, &b).ValueOrDie().Get<int32>({}, {0}),
+ 2);
+ EXPECT_EQ(
+ ComputeUpperBoundLiteral(tuple_2, &b).ValueOrDie().Get<int32>({}, {1}),
+ 3);
+}
+
+TEST_F(UpperBoundInferenceTest, GetDimensionSizeSub) {
+ XlaBuilder b(TestName());
+ auto p =
+ Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}), "p0");
+
+  // The range of the first dimension is [0, 2].
+  auto gds0 = GetDimensionSize(p, 0);
+  // The range of the second dimension is [3, 3].
+  auto gds1 = GetDimensionSize(p, 1);
+  // The upper bound of `second_dimension - first_dimension` is 3 - 0 = 3.
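+  // (In general, the upper bound of `a - b` is upper(a) - lower(b).)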
+ auto sub = Sub(gds1, gds0);
+ EXPECT_EQ(ComputeUpperBoundLiteral(sub, &b).ValueOrDie().Get<int32>({}), 3);
+}
+
+TEST_F(UpperBoundInferenceTest, GetDimensionSizeDiv) {
+ XlaBuilder b(TestName());
+ auto p =
+ Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}), "p0");
+  // The range of the first dimension is [0, 2].
+  auto gds0 = GetDimensionSize(p, 0);
+  // The range of the second dimension is [3, 3].
+ auto gds1 = GetDimensionSize(p, 1);
+  // The upper bound of `second_dimension / first_dimension` is 3 / 1 = 3. Note
+  // that we don't use 0 as the divisor's lower bound, as that would create a
+  // divide-by-zero error.
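+  // (In general, for non-negative operands, the upper bound of `a / b` is
+  // upper(a) / max(lower(b), 1).)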
+  auto div = Div(gds1, gds0);
+  EXPECT_EQ(ComputeUpperBoundLiteral(div, &b).ValueOrDie().Get<int32>({}), 3);
+}
+
+TEST_F(UpperBoundInferenceTest, ParamCantInferBound) {
+  // We can infer the bound of a parameter's dimension size, but not the bound
+  // of the parameter's values.
+ XlaBuilder b(TestName());
+ auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2}, {true}), "p0");
+ auto p1 = Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}, {}), "p1");
+ auto gds = GetDimensionSize(p0, 0);
+  auto div = Div(gds, p1);
+  EXPECT_FALSE(ComputeUpperBoundLiteral(div, &b)
+                   .ValueOrDie()
+                   .Get<int32>({})
+                   .has_value());
+}
+
} // namespace
} // namespace xla