[XLA] Use dynamism inference to infer dynamic dimensions for reshape.
- Introduce a dynamism inference function in the XLA builder that tells whether a value is dynamic or static.
- Use dynamism inference to infer whether an entry of reshape's dimensions input is dynamic.
- This removes the "-1" hack I made before in the bridge and makes the code cleaner. It also supports more complex dynamic reshape cases where the dimension comes from a series of transformations.
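
For illustration, a minimal end-to-end sketch of the new API (hedged: assumes an xla::Client* named "client" obtained from ClientLibrary; this mirrors the new dynamism_inference_test.cc below):

    // Build a computation mixing a static constant and a dynamic parameter.
    xla::XlaBuilder b("dynamism_example");
    auto c = xla::ConstantR0<int32>(&b, 42);  // static
    auto p = xla::Parameter(&b, 0, xla::ShapeUtil::MakeScalarShape(xla::S32), "p");  // dynamic
    auto root = xla::Tuple(&b, {xla::Add(c, c), xla::Add(p, c)});
    // Rewrite into a PRED-typed graph where true means "dynamic", then
    // evaluate it like a constant subgraph. Result: (false, true).
    xla::XlaComputation subgraph = b.BuildDynamicInferenceGraph(root).ValueOrDie();
    xla::Literal dynamism = client->ComputeConstant(subgraph).ValueOrDie();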
PiperOrigin-RevId: 325532056
Change-Id: Icc5bad39a857be77537e4736dd6863b833e2fe9d
diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc
index bf9a915..a85ba54 100644
--- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc
@@ -109,27 +109,33 @@
VLOG(2) << "Reshape from " << input_shape.DebugString() << " to "
<< shape.DebugString() << ", unknown_index=" << unknown_index;
- shape_input.clear();
- // Run get input again, this time with dynamic dimension represented as
- // "-1"
- ctx->set_dynamic_dimension_is_minus_one(true);
- OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &shape_input));
-
int dynamic_dimension = -1;
-
- for (int d = 0; d < num_dims; ++d) {
- const int32 size = shape_input[d];
- if (size == -1) {
- if (dynamic_dimension == -1) {
+ if (ctx->InputXlaShape(0)->is_dynamic()) {
+ std::vector<bool> dynamic_dims;
+ OP_REQUIRES_OK(ctx,
+ ctx->ResolveInputDynamismIntoPredVector(1, &dynamic_dims));
+ for (int d = 0; d < num_dims; ++d) {
+ const bool dim_is_dynamic = dynamic_dims[d];
+ if (dim_is_dynamic) {
dynamic_dimension = d;
- } else {
- if (unknown_index != d) {
- dynamic_dimension = d;
- }
}
}
- }
+ // When reshaping from a dynamic dimension, the unknown index is considered
+ // dynamic. E.g.,
+ // [<=10]
+ // |
+ // Reshape
+ // |
+ // [2, -1]
+ // The second dimension is dynamic.
+ if (dynamic_dimension == -1) {
+ dynamic_dimension = unknown_index;
+ }
+ VLOG(2) << "Reshape from " << ctx->InputXlaShape(0)->ToString() << " to "
+ << xla::VectorString(shape.dim_sizes())
+ << ", dynamic_dim=" << dynamic_dimension;
+ }
// Pass unknown_index to Xla::Reshape as a hint for dynamic shape inference
// in XLA to know which output dimension is dynamic.
ctx->SetOutput(0, xla::ReshapeWithInferredDimension(
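(For intuition, a hypothetical trace of the block above, values invented: reshaping an s32[<=10] input with shape argument {2, -1}, where neither entry of the shape argument is itself computed from a dynamic value, gives dynamic_dims = {false, false}; dynamic_dimension therefore stays -1 and falls back to unknown_index = 1, the position of the -1, which XLA then treats as the dynamic output dimension.)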
diff --git a/tensorflow/compiler/tf2xla/xla_expression.cc b/tensorflow/compiler/tf2xla/xla_expression.cc
index 34e108b..f0cc8d2 100644
--- a/tensorflow/compiler/tf2xla/xla_expression.cc
+++ b/tensorflow/compiler/tf2xla/xla_expression.cc
@@ -101,6 +101,48 @@
});
}
+xla::StatusOr<Tensor> XlaExpression::ResolveDynamism(
+ xla::Client* client) const {
+ switch (kind()) {
+ case Kind::kConstant: {
+ // Constant values are considered static.
+ Tensor constant_false(DT_BOOL, constant_value().shape());
+ auto flat = constant_false.flat<bool>();
+ for (int64 i = 0; i < flat.size(); ++i) flat(i) = false;
+ return constant_false;
+ }
+ case Kind::kXlaOp:
+ break;
+ case Kind::kTensorList:
+ TF_FALLTHROUGH_INTENDED;
+ case Kind::kResource:
+ TF_FALLTHROUGH_INTENDED;
+ case Kind::kInvalid:
+ return errors::InvalidArgument(
+ "ResolveDynamism called on unsupported XlaExpression: ",
+ HumanString());
+ }
+
+ if (!client)
+ return errors::InvalidArgument("client is required to resolve constant");
+
+ TF_ASSIGN_OR_RETURN(xla::XlaComputation constant_graph,
+ handle().builder()->BuildDynamicInferenceGraph(handle()));
+
+ TF_ASSIGN_OR_RETURN(TensorShape shape, GetShape());
+
+ // The XLA layout is specified minor to major, and TensorFlow uses a major to
+ // minor order.
+ std::vector<int64> layout_indices(shape.dims());
+ std::iota(layout_indices.rbegin(), layout_indices.rend(), 0);
+ xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices);
+ TF_ASSIGN_OR_RETURN(xla::Literal literal,
+ client->ComputeConstant(constant_graph, &layout));
+ Tensor tensor(DT_BOOL);
+ TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, DT_BOOL, &tensor));
+ return tensor;
+}
+
xla::StatusOr<absl::optional<Tensor>> XlaExpression::ResolveConstant(
xla::Client* client, bool dynamic_dimension_is_minus_one) const {
switch (kind()) {
diff --git a/tensorflow/compiler/tf2xla/xla_expression.h b/tensorflow/compiler/tf2xla/xla_expression.h
index 3010964..3546368 100644
--- a/tensorflow/compiler/tf2xla/xla_expression.h
+++ b/tensorflow/compiler/tf2xla/xla_expression.h
@@ -99,6 +99,10 @@
xla::StatusOr<absl::optional<Tensor>> ResolveConstant(
xla::Client* client, bool dynamic_dimension_is_minus_one = false) const;
+ // ResolveDynamism computes whether each value inside this expression is
+ // dynamic or can be inferred at compile time.
+ xla::StatusOr<Tensor> ResolveDynamism(xla::Client* client) const;
+
// Returns the shape of the tensor.
// The shape of a resource is the shape of a resource handle (i.e., a scalar),
// not the shape of the resource's value.
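A usage sketch (hedged; "e" is assumed to be a kXlaOp XlaExpression and "client" a valid xla::Client*):

    // The result is a DT_BOOL tensor of the same shape as the value; an
    // entry is true iff the corresponding element is dynamic.
    xla::StatusOr<Tensor> dynamism_or = e.ResolveDynamism(client);
    if (dynamism_or.ok()) {
      auto flat = dynamism_or.ValueOrDie().flat<bool>();
      for (int64 i = 0; i < flat.size(); ++i) {
        VLOG(2) << "element " << i << (flat(i) ? " dynamic" : " static");
      }
    }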
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index 735a6c7..0753754 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -243,6 +243,48 @@
return LiteralToFloat64Scalar(literal, out);
}
+static Status LiteralToPredVector(const xla::LiteralSlice& literal,
+ std::vector<bool>* out) {
+ if (literal.shape().rank() != 1) {
+ return errors::InvalidArgument("value is not 1D, rank: ",
+ literal.shape().rank());
+ }
+ int64 size = xla::ShapeUtil::ElementsIn(literal.shape());
+ if (literal.shape().element_type() != xla::PRED) {
+ return errors::InvalidArgument("value is not PRED");
+ }
+ for (int64 i = 0; i < size; ++i) {
+ out->push_back(literal.Get<bool>({i}));
+ }
+ return Status::OK();
+}
+
+Status XlaOpKernelContext::ResolveInputDynamismIntoPredVector(
+ int index, std::vector<bool>* out) {
+ xla::Literal literal;
+ XlaExpression e = InputExpression(index);
+ auto* client = compiler() ? compiler()->client() : nullptr;
+ xla::StatusOr<Tensor> dynamism_or_status = e.ResolveDynamism(client);
+ if (!dynamism_or_status.ok()) {
+ Status status = dynamism_or_status.status();
+ errors::AppendToMessage(&status, "while evaluating input dynamism ", index,
+ " of ", context_->op_kernel().type_string());
+ return status;
+ }
+ Tensor dynamism = dynamism_or_status.ValueOrDie();
+
+ Tensor temp(dynamism.dtype());
+ TensorShape tensor_shape({InputShape(index).num_elements()});
+ if (!temp.CopyFrom(dynamism, tensor_shape)) {
+ return errors::InvalidArgument(
+ context_->op_kernel().name(), " input ", index, " has shape ",
+ dynamism.shape().DebugString(), " which is not an R1 ", tensor_shape);
+ }
+
+ TF_ASSIGN_OR_RETURN(literal, HostTensorToLiteral(temp));
+ return LiteralToPredVector(literal, out);
+}
+
// Converts an int32 or int64 1D literal to an int64 vector.
static Status LiteralToInt64Vector(const xla::LiteralSlice& literal,
std::vector<int64>* out) {
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index 3cf51e6..75c3e60 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -116,6 +116,9 @@
// returns a one-element list.
Status InputList(absl::string_view name, std::vector<xla::XlaOp>* handles,
std::vector<TensorShape>* shapes);
+ // Evaluates the input and returns its dynamism as a vector of predicates
+ // (true means the corresponding element is dynamic).
+ Status ResolveInputDynamismIntoPredVector(int index, std::vector<bool>* out);
// Helper methods for constant inputs.
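A usage sketch from inside an op kernel's Compile method (hedged; input index 1 is assumed to be a 1D shape operand, matching the reshape kernel above):

    // Ask which entries of input 1 are only known at run time.
    std::vector<bool> dynamic_dims;
    OP_REQUIRES_OK(ctx, ctx->ResolveInputDynamismIntoPredVector(1, &dynamic_dims));
    for (int d = 0; d < static_cast<int>(dynamic_dims.size()); ++d) {
      if (dynamic_dims[d]) {
        // Dimension d of the requested shape comes from a dynamic value.
      }
    }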
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 484fb0a..8de8216 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -32,6 +32,7 @@
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -39,6 +40,7 @@
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/errors.h"
namespace xla {
@@ -71,6 +73,52 @@
entry->set_id(id);
entry->set_name(GetFullName(base_name, separator, id));
}
+
+ShapeProto ConvertShapeProtoToPred(const ShapeProto& shape_proto) {
+ return ShapeUtil::ChangeElementType(Shape(shape_proto), PRED).ToProto();
+}
+
+HloInstructionProto CreateConstantInstruction(int64 id, const Shape& shape,
+ bool pred) {
+ HloInstructionProto const_instr;
+ Literal literal = LiteralUtil::CreateR0(pred);
+ Literal literal_broadcast = literal.Broadcast(shape, {}).ValueOrDie();
+ *const_instr.mutable_shape() = shape.ToProto();
+ *const_instr.mutable_literal() = literal_broadcast.ToProto();
+ *const_instr.mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);
+ const_instr.set_id(id);
+ return const_instr;
+}
+
+// Rewrites a reducer HloComputation into a reduce-or computation with
+// predicate types.
+HloComputationProto CreateReduceOr(int64 reducer_id,
+ HloComputationProto* original_reducer) {
+ HloComputationProto reducer;
+ SetProtoIdAndName(&reducer, StrCat("reduce_or"), kNameSeparator, reducer_id);
+ std::vector<int64> operands_id;
+ for (auto& inst : original_reducer->instructions()) {
+ // Copy params.
+ if (StringToHloOpcode(inst.opcode()).ValueOrDie() ==
+ HloOpcode::kParameter) {
+ HloInstructionProto* new_param = reducer.add_instructions();
+ *new_param = inst;
+ *new_param->mutable_shape() = ConvertShapeProtoToPred(inst.shape());
+ operands_id.push_back(inst.id());
+ }
+ if (inst.id() == original_reducer->root_id()) {
+ HloInstructionProto* new_root = reducer.add_instructions();
+ *new_root = inst;
+ *new_root->mutable_shape() = ConvertShapeProtoToPred(inst.shape());
+ *new_root->mutable_opcode() = HloOpcodeString(HloOpcode::kOr);
+ new_root->clear_operand_ids();
+ for (int64 operand_id : operands_id) {
+ new_root->add_operand_ids(operand_id);
+ }
+ reducer.set_root_id(inst.id());
+ }
+ }
+ return reducer;
+}
} // namespace
namespace internal {
@@ -2842,6 +2890,196 @@
return is_constant;
}
+StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
+ TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
+ LookUpInstruction(root_op));
+
+ HloComputationProto entry;
+ SetProtoIdAndName(&entry, StrCat(name_, "_dynamic_inference"), kNameSeparator,
+ GetNextId());
+ ProgramShapeProto* program_shape = entry.mutable_program_shape();
+ *program_shape->mutable_result() =
+ ShapeUtil::ChangeElementType(Shape(root->shape()), PRED).ToProto();
+
+ std::set<int64> seen;
+ struct WorkItem {
+ explicit WorkItem(int64 handle, bool need_rewrite)
+ : handle(handle), need_rewrite(need_rewrite) {}
+ int64 handle;
+ // If need_rewrite is true, the instruction is copied and rewritten into a
+ // pred instruction indicating whether each value is dynamic. If
+ // need_rewrite is false, the instruction is simply copied to the output
+ // graph.
+ // E.g.,
+ // For select(P, A, B), we need to rewrite A and B into predicates, but
+ // don't need to rewrite P.
+ bool need_rewrite;
+ };
+ std::queue<WorkItem> worklist;
+ worklist.push(WorkItem(root->id(), true));
+ entry.set_root_id(root->id());
+ std::vector<HloComputationProto> called_computations;
+ // Rewrites the instruction with id "from" into the new graph.
+ // Returns additional work items for operands that still need processing.
+ auto rewrite_instruction =
+ [&](int64 from, bool need_rewrite) -> StatusOr<std::vector<WorkItem>> {
+ // Rewrite the instruction with the following rules:
+ // - Unary ops: Convert into bitcast (identity) with type Pred.
+ // - Binary ops: Convert into binary or.
+ // - Select: Copy, but the predicate operand is used as-is instead of
+ // being rewritten.
+ // - Gather: Copy, but the indices operand is used as-is instead of
+ // being rewritten.
+ // - Concat / Tuple / GTE / Transpose / Slice / Broadcast / Reshape: Copy.
+ // - Param: Convert to constant True.
+ // - GetDimensionSize: Convert to constant True if dimension is dynamic,
+ // constant False if dimension is static.
+ // - Reduce: Convert to reduce or.
+ // - Constant: Convert to constant False.
+ // - Other ops: Not supported.
+ // Create the instruction for the new handle.
+ TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto,
+ LookUpInstructionByHandle(from));
+
+ TF_ASSIGN_OR_RETURN(HloOpcode opcode,
+ StringToHloOpcode(instr_proto->opcode()));
+ std::vector<WorkItem> operands_todo;
+ auto* new_instr = entry.add_instructions();
+ *new_instr = *instr_proto;
+ for (auto operand_id : new_instr->operand_ids()) {
+ operands_todo.emplace_back(operand_id, need_rewrite);
+ }
+
+ if (!need_rewrite) {
+ *new_instr->mutable_name() =
+ GetFullName(instr_proto->opcode(), kNameSeparator, instr_proto->id());
+ return operands_todo;
+ }
+ *new_instr->mutable_shape() = ConvertShapeProtoToPred(instr_proto->shape());
+ Shape new_shape(new_instr->shape());
+ switch (opcode) {
+ case HloOpcode::kAbs:
+ case HloOpcode::kRoundNearestAfz:
+ case HloOpcode::kBitcast:
+ case HloOpcode::kCeil:
+ case HloOpcode::kCollectivePermuteDone:
+ case HloOpcode::kCos:
+ case HloOpcode::kClz:
+ case HloOpcode::kExp:
+ case HloOpcode::kExpm1:
+ case HloOpcode::kFloor:
+ case HloOpcode::kImag:
+ case HloOpcode::kIsFinite:
+ case HloOpcode::kLog:
+ case HloOpcode::kLog1p:
+ case HloOpcode::kNot:
+ case HloOpcode::kNegate:
+ case HloOpcode::kPopulationCount:
+ case HloOpcode::kReal:
+ case HloOpcode::kRsqrt:
+ case HloOpcode::kLogistic:
+ case HloOpcode::kSign:
+ case HloOpcode::kSin:
+ case HloOpcode::kConvert:
+ case HloOpcode::kSqrt:
+ case HloOpcode::kCbrt:
+ case HloOpcode::kTanh:
+ CHECK_EQ(instr_proto->operand_ids_size(), 1);
+ *new_instr->mutable_opcode() = HloOpcodeString(HloOpcode::kBitcast);
+ break;
+ case HloOpcode::kAdd:
+ case HloOpcode::kAtan2:
+ case HloOpcode::kDivide:
+ case HloOpcode::kComplex:
+ case HloOpcode::kMaximum:
+ case HloOpcode::kMinimum:
+ case HloOpcode::kMultiply:
+ case HloOpcode::kPower:
+ case HloOpcode::kRemainder:
+ case HloOpcode::kSubtract:
+ case HloOpcode::kCompare:
+ case HloOpcode::kAnd:
+ case HloOpcode::kOr:
+ case HloOpcode::kXor:
+ case HloOpcode::kShiftLeft:
+ case HloOpcode::kShiftRightArithmetic:
+ case HloOpcode::kShiftRightLogical:
+ CHECK_EQ(instr_proto->operand_ids_size(), 2);
+ *new_instr->mutable_opcode() = HloOpcodeString(HloOpcode::kOr);
+ break;
+ case HloOpcode::kSelect:
+ operands_todo[0].need_rewrite = false;
+ break;
+ case HloOpcode::kGather:
+ operands_todo[1].need_rewrite = false;
+ break;
+ case HloOpcode::kReduce: {
+ int64 reducer_id = new_instr->called_computation_ids(0);
+ called_computations.push_back(
+ CreateReduceOr(reducer_id, &embedded_[reducer_id]));
+ break;
+ }
+ case HloOpcode::kTuple:
+ case HloOpcode::kTranspose:
+ case HloOpcode::kGetTupleElement:
+ case HloOpcode::kSlice:
+ case HloOpcode::kBroadcast:
+ case HloOpcode::kConcatenate:
+ case HloOpcode::kReshape:
+ break;
+ case HloOpcode::kGetDimensionSize: {
+ int64 dimension = instr_proto->dimensions(0);
+ int64 operand_handle = instr_proto->operand_ids(0);
+ TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
+ LookUpInstructionByHandle(operand_handle));
+
+ *new_instr = CreateConstantInstruction(
+ from, new_shape,
+ operand_proto->shape().is_dynamic_dimension(dimension));
+ operands_todo.clear();
+ break;
+ }
+ case HloOpcode::kConstant:
+ *new_instr = CreateConstantInstruction(from, new_shape, false);
+ break;
+ case HloOpcode::kParameter:
+ *new_instr = CreateConstantInstruction(from, new_shape, true);
+ break;
+ default:
+ return InvalidArgument("Dynamic inferencing %s is not supported",
+ instr_proto->DebugString());
+ }
+ *new_instr->mutable_name() =
+ GetFullName(instr_proto->opcode(), kNameSeparator, instr_proto->id());
+ return operands_todo;
+ };
+
+ while (!worklist.empty()) {
+ WorkItem item = worklist.front();
+ worklist.pop();
+ if (!seen.insert(item.handle).second) {
+ continue;
+ }
+ TF_ASSIGN_OR_RETURN(auto todos,
+ rewrite_instruction(item.handle, item.need_rewrite));
+ for (WorkItem& todo : todos) {
+ worklist.push(todo);
+ }
+ }
+ absl::c_sort(*entry.mutable_instructions(),
+ [](const HloInstructionProto& p1,
+ const HloInstructionProto& p2) { return p1.id() < p2.id(); });
+ XlaComputation computation(entry.id());
+ HloModuleProto* module = computation.mutable_proto();
+ module->set_name(entry.name());
+ module->set_id(entry.id());
+ module->set_entry_computation_name(entry.name());
+ module->set_entry_computation_id(entry.id());
+ *module->mutable_host_program_shape() = *program_shape;
+ for (auto& called_comp : called_computations) {
+ *module->add_computations() = called_comp;
+ }
+ *module->add_computations() = std::move(entry);
+ XLA_VLOG_LINES(3, module->DebugString());
+ return std::move(computation);
+}
+
StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
XlaOp root_op, bool dynamic_dimension_is_minus_one) {
TF_ASSIGN_OR_RETURN(bool is_constant, IsConstant(root_op));
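To make the rewrite rules concrete, a hypothetical before/after in the informal style of the BuildDynamicInferenceGraph header comment (instruction names invented):

    Value graph:                      Dynamism graph (PRED):
      p = s32[] parameter(0)            p' = constant(true)
      c = s32[] constant(42)            c' = constant(false)
      a = s32[] add(p, c)               a' = or(p', c')
      ROOT t = tuple(a, c)              ROOT t' = tuple(a', c')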
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index aa5074d..6753b6d 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -278,6 +278,31 @@
StatusOr<XlaComputation> BuildConstantSubGraph(
XlaOp root_op, bool dynamic_dimension_is_uint_max = false);
+ // Similar to BuildConstantSubGraph, but with the root element type changed
+ // to boolean. A true value in the root indicates that the value is dynamic,
+ // while a false value indicates that the value is a constant. This copies
+ // the needed ops/computations to the subgraph.
+ //
+ // E.g.,
+ // Computation {
+ // a = 3
+ // b = param(0)
+ // ROOT Tuple(a + b, a + 1, b + 1)
+ // }
+ // Calling BuildDynamicInferenceGraph on root will produce the following
+ // graph:
+ //
+ // Computation {
+ // a = False
+ // b = True
+ // ROOT Tuple(a | b, a, b)
+ // }
+ //
+ // The result, which is (True, False, True) after evaluation, can be
+ // interpreted as "First element is dynamic; Second element is static; Third
+ // element is dynamic".
+ StatusOr<XlaComputation> BuildDynamicInferenceGraph(XlaOp root_op);
+
// Returns the first error that was encountered while building the
// computation. When an error is encountered, by default we return a vacuous
// XlaOp and inform the user of the error that occurred while
diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
index 2f24568..36429d3 100644
--- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
+++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
@@ -805,7 +805,8 @@
}
if (input_dim_size > output_dim_size) {
- TF_RET_CHECK(input_dim_size % output_dim_size == 0);
+ TF_RET_CHECK(input_dim_size % output_dim_size == 0)
+ << reshape->ToString();
const int64 divisor = input_dim_size / output_dim_size;
HloInstruction* divisor_hlo =
hlo->parent()->AddInstruction(HloInstruction::CreateConstant(
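(A hypothetical worked example of this path, shapes invented: reshaping s32[<=12] into s32[3, 4] gives input_dim_size = 12 and output_dim_size = 4, so divisor = 12 / 4 = 3, and the dynamic size of that output dimension is derived at run time as the dynamic input size divided by 3.)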
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 02fcaaf..0833919 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -783,9 +783,18 @@
/* static */ Shape ShapeUtil::ChangeElementType(const Shape& original,
PrimitiveType type) {
- Shape new_shape = original;
- new_shape.set_element_type(type);
- return new_shape;
+ if (original.IsTuple()) {
+ std::vector<Shape> new_operands;
+ new_operands.reserve(original.tuple_shapes_size());
+ for (const Shape& operand : original.tuple_shapes()) {
+ new_operands.push_back(ChangeElementType(operand, type));
+ }
+ return MakeTupleShape(new_operands);
+ } else {
+ Shape new_shape = original;
+ new_shape.set_element_type(type);
+ return new_shape;
+ }
}
/* static */ bool ShapeUtil::IndexIsValid(const Shape& shape,
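A small sketch of the new tuple behavior (shapes invented for illustration):

    // ChangeElementType now recurses into tuple shapes.
    xla::Shape s = xla::ShapeUtil::MakeTupleShape(
        {xla::ShapeUtil::MakeShape(xla::F32, {2, 3}),
         xla::ShapeUtil::MakeShape(xla::S32, {4})});
    xla::Shape pred = xla::ShapeUtil::ChangeElementType(s, xla::PRED);
    // pred is (pred[2,3], pred[4]); previously the element type would have
    // been set on the tuple shape itself, producing an invalid shape.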
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 927f9d1..17444c0 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -2089,6 +2089,31 @@
)
xla_test(
+ name = "dynamism_inference_test",
+ srcs = ["dynamism_inference_test.cc"],
+ deps = [
+ ":test_macros_header",
+ "//tensorflow/compiler/xla:literal",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:xla_data_proto_cc",
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:xla_builder",
+ "//tensorflow/compiler/xla/client:xla_computation",
+ "//tensorflow/compiler/xla/client/lib:prng",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/compiler/xla/tests:test_utils",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+xla_test(
name = "compute_constant_test",
srcs = ["compute_constant_test.cc"],
deps = [
diff --git a/tensorflow/compiler/xla/tests/dynamism_inference_test.cc b/tensorflow/compiler/xla/tests/dynamism_inference_test.cc
new file mode 100644
index 0000000..ba4092d
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/dynamism_inference_test.cc
@@ -0,0 +1,215 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/strings/match.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/lib/prng.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+// An enumerator for the client types that we want to iterate over in
+// the various tests.
+enum class ClientType { kLocal, kCompileOnly };
+ClientType client_types[] = {ClientType::kLocal, ClientType::kCompileOnly};
+
+class DynamismInferenceTest : public ::testing::Test {
+ public:
+ explicit DynamismInferenceTest(se::Platform* platform = nullptr)
+ : platform_(platform) {}
+
+ string TestName() const {
+ return ::testing::UnitTest::GetInstance()->current_test_info()->name();
+ }
+
+ Client* ClientOrDie(se::Platform* platform, ClientType client_type) {
+ if (client_type == ClientType::kLocal) {
+ StatusOr<Client*> result =
+ ClientLibrary::GetOrCreateLocalClient(platform);
+ TF_CHECK_OK(result.status())
+ << "could not create LocalClient for testing";
+ return result.ValueOrDie();
+ } else if (client_type == ClientType::kCompileOnly) {
+ StatusOr<Client*> result =
+ ClientLibrary::GetOrCreateCompileOnlyClient(platform);
+ TF_CHECK_OK(result.status())
+ << "could not create CompileOnlyClient for testing";
+ return result.ValueOrDie();
+ }
+ LOG(FATAL) << "invalid client_type value";
+ }
+
+ StatusOr<Literal> ComputeDynamismLiteral(Client* client, XlaOp operand,
+ XlaBuilder* builder,
+ Layout* output_layout = nullptr) {
+ TF_ASSIGN_OR_RETURN(auto subgraph,
+ builder->BuildDynamicInferenceGraph(operand));
+ TF_ASSIGN_OR_RETURN(auto computed,
+ client->ComputeConstant(subgraph, output_layout));
+ return std::move(computed);
+ }
+
+ StatusOr<bool> ComputeDynamismScalar(Client* client, XlaOp operand,
+ XlaBuilder* builder,
+ ShapeIndex index = {}) {
+ TF_ASSIGN_OR_RETURN(auto literal, ComputeDynamismLiteral(client, operand,
+ builder, nullptr));
+ return literal.Get<bool>({}, index);
+ }
+
+ se::Platform* platform_;
+};
+
+TEST_F(DynamismInferenceTest, ScalarInt32Literal) {
+ for (ClientType client_type : client_types) {
+ Client* client = ClientOrDie(platform_, client_type);
+ XlaBuilder b(TestName());
+ auto computation = ConstantR0<int32>(&b, 42);
+
+ auto value = ComputeDynamismScalar(client, computation, &b);
+ ASSERT_TRUE(value.ok()) << value.status();
+ // A constant is not dynamic.
+ EXPECT_EQ(value.ValueOrDie(), false);
+ }
+}
+
+TEST_F(DynamismInferenceTest, TupleGteKeepsDynamism) {
+ for (ClientType client_type : client_types) {
+ Client* client = ClientOrDie(platform_, client_type);
+ XlaBuilder b(TestName());
+ auto c = ConstantR0<int32>(&b, 42);
+ auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");
+
+ auto tuple = Tuple(&b, {c, p});
+ auto gte0 = GetTupleElement(tuple, 0);
+ auto gte1 = GetTupleElement(tuple, 1);
+ auto tuple_2 = Tuple(&b, {gte0, gte1});
+ EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(),
+ false);
+ EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(),
+ true);
+ }
+}
+
+TEST_F(DynamismInferenceTest, ConcatSliceReshapeKeepsDynamism) {
+ for (ClientType client_type : client_types) {
+ Client* client = ClientOrDie(platform_, client_type);
+ XlaBuilder b(TestName());
+ auto c = ConstantR0<int32>(&b, 42);
+ auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");
+
+ auto concat = ConcatScalars(&b, {c, p});
+ auto slice0 = SliceInDim(concat, 0, 1, 1, 0);
+ auto reshape0 = Reshape(slice0, {});
+ auto slice1 = SliceInDim(concat, 1, 2, 1, 0);
+ auto reshape1 = Reshape(slice1, {});
+ auto tuple_2 = Tuple(&b, {reshape0, reshape1});
+ EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(),
+ false);
+ EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(),
+ true);
+ }
+}
+
+TEST_F(DynamismInferenceTest, ParameterIsDynamic) {
+ for (ClientType client_type : client_types) {
+ Client* client = ClientOrDie(platform_, client_type);
+ XlaBuilder b(TestName());
+ auto computation = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");
+
+ auto value = ComputeDynamismScalar(client, computation, &b);
+ ASSERT_TRUE(value.ok()) << value.status();
+ // A parameter is considered dynamic.
+ EXPECT_EQ(value.ValueOrDie(), true);
+ }
+}
+
+TEST_F(DynamismInferenceTest, UnaryOpKeepsDynamism) {
+ for (ClientType client_type : client_types) {
+ Client* client = ClientOrDie(platform_, client_type);
+ XlaBuilder b(TestName());
+ auto c = ConstantR0<int32>(&b, 42);
+ auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");
+
+ auto neg0 = Neg(c);
+ auto neg1 = Neg(p);
+ auto tuple_2 = Tuple(&b, {neg0, neg1});
+ EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(),
+ false);
+ EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(),
+ true);
+ }
+}
+
+TEST_F(DynamismInferenceTest, BinaryOpsOrsDynamism) {
+ for (ClientType client_type : client_types) {
+ Client* client = ClientOrDie(platform_, client_type);
+ XlaBuilder b(TestName());
+ auto c = ConstantR0<int32>(&b, 42);
+ auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");
+
+ // Static value + static value = static
+ auto add1 = Add(c, c);
+ // Dynamic value + static value = dynamic
+ auto add2 = Add(p, c);
+ auto tuple_2 = Tuple(&b, {add1, add2});
+ EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(),
+ false);
+ EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(),
+ true);
+ }
+}
+
+TEST_F(DynamismInferenceTest, GetDimensionSize) {
+ for (ClientType client_type : client_types) {
+ Client* client = ClientOrDie(platform_, client_type);
+ XlaBuilder b(TestName());
+ // param = Param([<=2, 3])
+ // get_dimension_size(param, 0) is dynamic
+ // get_dimension_size(param, 1) is static
+ auto p =
+ Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}), "0");
+
+ auto gds0 = GetDimensionSize(p, 0);
+ auto gds1 = GetDimensionSize(p, 1);
+ auto tuple_2 = Tuple(&b, {gds0, gds1});
+ EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(),
+ true);
+ EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(),
+ false);
+ }
+}
+
+} // namespace
+} // namespace xla