Allow value inference to analyze into embedded conditional computations.
PiperOrigin-RevId: 369016974
Change-Id: I8c37bedb8a2621b3a65c6270a01b585f7b33a2d6
diff --git a/tensorflow/compiler/xla/client/value_inference.cc b/tensorflow/compiler/xla/client/value_inference.cc
index 3271347..7dd2395 100644
--- a/tensorflow/compiler/xla/client/value_inference.cc
+++ b/tensorflow/compiler/xla/client/value_inference.cc
@@ -314,9 +314,8 @@
case HloOpcode::kAbs:
case HloOpcode::kDivide:
case HloOpcode::kGetDimensionSize: {
- return InvalidArgument(
- "AnalyzeConstantValueFallback can't handle opcode: %s",
- root->opcode());
+ return InvalidArgument("AnalyzeConstantValue can't handle opcode: %s",
+ root->opcode());
}
case HloOpcode::kGetTupleElement: {
int64 operand_handle = root->operand_ids(0);
@@ -332,52 +331,6 @@
});
}
- if (operand_opcode == HloOpcode::kConditional) {
- int64 index = root->tuple_index();
- auto node = PostorderDFSNode();
- auto* conditional_proto = operand_proto;
- // Add dependencies to analyze the predicate of the conditional.
- node.AddDependency(conditional_proto->operand_ids(0),
- PostorderDFSNodeType::kConstantValue)
- .AddDependency(conditional_proto->operand_ids(0),
- PostorderDFSNodeType::kValueIsDynamic);
- const int64 branch_size =
- conditional_proto->called_computation_ids_size();
- for (int64 i = 0; i < branch_size; ++i) {
- int64 branch_root = handle_to_computation(
- conditional_proto->called_computation_ids(i))
- ->root_id();
- int64 branch_root_operand =
- handle_to_instruction(branch_root)->operand_ids(index);
- node.AddDependency(branch_root_operand,
- PostorderDFSNodeType::kConstantValue);
- }
- return node.AddVisit(
- [](absl::Span<Literal> operands) -> StatusOr<Literal> {
- int64 pred_is_dynamic = operands[1].Get<bool>({});
- if (pred_is_dynamic) {
- // If predicate is dynamic, return the value of the first branch
- // -- if all branches return the same value, this is the value
- // that we want. If not, the value will be masked anyway so the
- // value inside doesn't matter.
- return std::move(operands[2]);
- } else {
- // If predicate is static, return the value of given branch.
- int64 branch_index = 0;
- if (operands[0].shape().element_type() == PRED) {
- if (operands[0].Get<bool>({})) {
- branch_index = 0;
- } else {
- branch_index = 1;
- }
- } else {
- branch_index = operands[0].GetIntegralAsS64({}).value();
- }
- const int64 branch_dynamism_index = 2 + branch_index;
- return std::move(operands[branch_dynamism_index]);
- }
- });
- }
return result.AddVisit([root](absl::Span<Literal> operands) {
return HloProtoEvaluator(*root)
.WithOperands(operands)
@@ -624,7 +577,6 @@
const HloInstructionProto* root = handle_to_instruction(handle);
// Invariant check.
TF_RET_CHECK(root);
- VLOG(1) << "Analyzing IsDynamic on " << root->DebugString();
TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
PostorderDFSNode result;
for (auto operand_id : root->operand_ids()) {
@@ -722,88 +674,11 @@
TF_ASSIGN_OR_RETURN(HloOpcode operand_opcode,
StringToHloOpcode(operand_proto->opcode()));
if (operand_opcode == HloOpcode::kParameter) {
- return PostorderDFSNode().AddVisit([root]() -> StatusOr<Literal> {
- // As an optimization, don't materialize the whole parameter if it's
- // followed by a GTE.
+ PostorderDFSNode().AddVisit([root]() -> StatusOr<Literal> {
+ // Don't materialize the whole parameter if it's followed by a GTE.
return CreatePredLiteral(true, Shape(root->shape()));
});
}
-
- if (operand_opcode == HloOpcode::kConditional) {
- int64 index = root->tuple_index();
- auto* conditional_proto = operand_proto;
- auto node = PostorderDFSNode();
- // Add dependencies to analyze the predicate of the conditional.
- node.AddDependency(conditional_proto->operand_ids(0),
- PostorderDFSNodeType::kConstantValue)
- .AddDependency(conditional_proto->operand_ids(0),
- PostorderDFSNodeType::kValueIsDynamic);
- const int64 branch_size =
- conditional_proto->called_computation_ids_size();
- for (int64 i = 0; i < branch_size; ++i) {
- int64 branch_root = handle_to_computation(
- conditional_proto->called_computation_ids(i))
- ->root_id();
- int64 branch_root_operand =
- handle_to_instruction(branch_root)->operand_ids(index);
- node.AddDependency(branch_root_operand,
- PostorderDFSNodeType::kConstantValue)
- .AddDependency(branch_root_operand,
- PostorderDFSNodeType::kValueIsDynamic);
- }
- // Predicate uses 2 dependencies:
- // 0: Predicate value.
- // 1: Predicate is dynamic.
- // Each branch i has 2 dependenices:
- // 2*i: Branch result value
- // 2*i + 1: Branch value is dynamic.
- return node.AddVisit([root, branch_size](absl::Span<Literal> operands)
- -> StatusOr<Literal> {
- int64 pred_is_dynamic = operands[1].Get<bool>({});
- auto result = CreatePredLiteral(true, Shape(root->shape()));
- if (pred_is_dynamic) {
- // If predicate is dynamic, the result is only static if all
- // branches are static and return the same value.
- auto result = CreatePredLiteral(true, Shape(root->shape()));
-
- result.MutableEachCell<bool>(
- [&](absl::Span<const int64> indices, bool value) {
- string branch_value = operands[2].GetAsString(indices, {});
- for (int64 i = 0; i < branch_size; ++i) {
- const int64 branch_value_index = 2 + 2 * i;
- const int64 branch_dynamism_index = 2 + 2 * i + 1;
- auto branch_is_dynamic =
- operands[branch_dynamism_index].Get<bool>(indices);
- if (branch_is_dynamic) {
- return true;
- }
- if (branch_value !=
- operands[branch_value_index].GetAsString(indices, {})) {
- return true;
- }
- }
- // Value of the branch is static.
- return false;
- });
- return result;
- } else {
- // If predicate is static, return true if given branch result
- // value is dynamic.
- int64 branch_index = 0;
- if (operands[0].shape().element_type() == PRED) {
- if (operands[0].Get<bool>({})) {
- branch_index = 0;
- } else {
- branch_index = 1;
- }
- } else {
- branch_index = operands[0].GetIntegralAsS64({}).value();
- }
- const int64 branch_dynamism_index = 2 + 2 * branch_index + 1;
- return std::move(operands[branch_dynamism_index]);
- }
- });
- }
return result.AddVisit([root](absl::Span<Literal> operands) {
return HloProtoEvaluator(*root)
.WithOperands(operands)
@@ -923,7 +798,7 @@
break;
}
default:
- return Unimplemented("Can't infer dynamism through %s: %s",
+ return Unimplemented("Can't infer upper bound through %s: %s",
root->opcode(), root->DebugString());
}
}
@@ -977,23 +852,19 @@
PostorderDFSNode node;
switch (item.type) {
case PostorderDFSNodeType::kConstantValue: {
- VLOG(1) << "constant value";
- TF_ASSIGN_OR_RETURN(node, AnalyzeConstant(item.handle));
+ TF_ASSIGN_OR_RETURN(node, AnalyzeConstant(item.handle));
break;
}
case PostorderDFSNodeType::kConstantLowerBound: {
- VLOG(1) << "constant lower bound";
- TF_ASSIGN_OR_RETURN(node, AnalyzeLowerBound(item.handle));
+ TF_ASSIGN_OR_RETURN(node, AnalyzeLowerBound(item.handle));
break;
}
case PostorderDFSNodeType::kConstantUpperBound: {
- VLOG(1) << "constant upper bound";
TF_ASSIGN_OR_RETURN(node, AnalyzeUpperBound(item.handle));
break;
}
case PostorderDFSNodeType::kBoundIsDynamic:
case PostorderDFSNodeType::kValueIsDynamic: {
- VLOG(1) << "value is dynamic";
TF_ASSIGN_OR_RETURN(node, AnalyzeIsDynamic(item.handle, item.type));
break;
}
@@ -1016,9 +887,8 @@
return builder_->LookUpInstructionByHandle(handle).ValueOrDie();
},
[&](int64 handle) { return &(builder_->embedded_[handle]); });
- auto result = visitor.PostOrderDFSVisit(
- op.handle(), PostorderDFSNodeType::kValueIsDynamic);
- return result;
+ return visitor.PostOrderDFSVisit(op.handle(),
+ PostorderDFSNodeType::kValueIsDynamic);
}
StatusOr<OptionalLiteral> ValueInference::AnalyzeConstant(
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 5f5aba6..9d249c1 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -3767,15 +3767,6 @@
}
int64 computation_id = imported_computation.id();
- for (int64 i = 0; i < imported_computation.instructions_size(); ++i) {
- ImportedInstruction imported_instruction;
- imported_instruction.computation_id = computation_id;
- imported_instruction.instruction_index = i;
- handle_to_imported_index_.insert(
- {imported_computation.instructions(i).id(),
- ImportedInstruction{.computation_id = computation_id,
- .instruction_index = i}});
- }
embedded_.insert({computation_id, std::move(imported_computation)});
}
}
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 6926fbc..df8ba65 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -1003,15 +1003,6 @@
// instruction is held.
absl::flat_hash_map<int64, int64> handle_to_index_;
- // Track imported instructions by their computation id and the position in
- // their computation's instruction list.
- struct ImportedInstruction {
- int64 computation_id;
- int64 instruction_index;
- };
-
- absl::flat_hash_map<int64, ImportedInstruction> handle_to_imported_index_;
-
// The embedded computations used by this computation. Each computation was
// the entry computation of some XlaComputation, the key is the unique id of
// that XlaComputation.
@@ -1454,14 +1445,6 @@
int64 handle) const {
auto it = handle_to_index_.find(handle);
if (it == handle_to_index_.end()) {
- // Try look for the instruction in the imported instructions.
- auto imported_it = handle_to_imported_index_.find(handle);
- if (imported_it != handle_to_imported_index_.end()) {
- ImportedInstruction imported = imported_it->second;
- return const_cast<InstructionType>(
- &embedded_.at(imported.computation_id)
- .instructions(imported.instruction_index));
- }
return InvalidArgument("No XlaOp with handle %d", handle);
}
return const_cast<InstructionType>(&instructions_.at(it->second));
diff --git a/tensorflow/compiler/xla/tests/dynamism_inference_test.cc b/tensorflow/compiler/xla/tests/dynamism_inference_test.cc
index 69c8083..a5019fa 100644
--- a/tensorflow/compiler/xla/tests/dynamism_inference_test.cc
+++ b/tensorflow/compiler/xla/tests/dynamism_inference_test.cc
@@ -298,139 +298,6 @@
EXPECT_TRUE(ComputeDynamismLiteral(pad, &b).ValueOrDie().Get<bool>({2}));
}
-TEST_F(DynamismInferenceTest, InferThroughConditionalBranchesAreSame) {
- // The result of following conditional is static.
- // pred = .. # a dynamic value
- // if (pred) {
- // return (1) # both branches return the same value
- // } else {
- // return (1)
- // }
- //
-
- auto s32_shape = ShapeUtil::MakeShape(S32, {});
- auto cond_shape = ShapeUtil::MakeTupleShape({s32_shape});
- XlaBuilder true_builder("true");
- Parameter(&true_builder, 0, s32_shape, "cond_param");
- Tuple(&true_builder, {ConstantR0<int32>(&true_builder, 1)});
- auto true_computation = true_builder.Build().ValueOrDie();
-
- XlaBuilder false_builder("false");
- Parameter(&false_builder, 0, s32_shape, "cond_param");
- Tuple(&false_builder, {ConstantR0<int32>(&false_builder, 1)});
- auto false_computation = false_builder.Build().ValueOrDie();
-
- XlaBuilder b(TestName());
- auto parameter = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {}), "p0");
- auto constant = ConstantR0<int32>(&b, 0);
- auto cond = Conditional(parameter, constant, true_computation, constant,
- false_computation);
- auto gte = GetTupleElement(cond, 0);
- ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
- // Result is not dynamic.
- EXPECT_FALSE(ComputeDynamismLiteral(gte, &b).ValueOrDie().Get<bool>({}));
-}
-
-TEST_F(DynamismInferenceTest, InferThroughConditionalBranchesAreNotSame) {
- // The result of following conditional is dynamic.
- // pred = .. # a dynamic value
- // if (pred) {
- // return (1) # These two branches return different values.
- // } else {
- // return (2)
- // }
- //
-
- auto s32_shape = ShapeUtil::MakeShape(S32, {});
- auto cond_shape = ShapeUtil::MakeTupleShape({s32_shape});
- XlaBuilder true_builder("true");
- Parameter(&true_builder, 0, s32_shape, "cond_param");
- Tuple(&true_builder, {ConstantR0<int32>(&true_builder, 1)});
- auto true_computation = true_builder.Build().ValueOrDie();
-
- XlaBuilder false_builder("false");
- Parameter(&false_builder, 0, s32_shape, "cond_param");
- Tuple(&false_builder, {ConstantR0<int32>(&false_builder, 2)});
- auto false_computation = false_builder.Build().ValueOrDie();
-
- XlaBuilder b(TestName());
- auto parameter = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {}), "p0");
- auto constant = ConstantR0<int32>(&b, 0);
- auto cond = Conditional(parameter, constant, true_computation, constant,
- false_computation);
- auto gte = GetTupleElement(cond, 0);
- ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
- // Result is dynamic.
- EXPECT_TRUE(ComputeDynamismLiteral(gte, &b).ValueOrDie().Get<bool>({}));
-}
-
-TEST_F(DynamismInferenceTest, InferThroughConditionalPredIsConstantTrueBranch) {
- // The result of following conditional is static.
- // pred = true
- // if (pred) {
- // return (1)
- // } else {
- // return (..dynamic_value...)
- // }
- //
-
- auto s32_shape = ShapeUtil::MakeShape(S32, {});
- auto cond_shape = ShapeUtil::MakeTupleShape({s32_shape});
- XlaBuilder true_builder("true");
- Parameter(&true_builder, 0, s32_shape, "cond_param");
- Tuple(&true_builder, {ConstantR0<int32>(&true_builder, 0)});
- auto true_computation = true_builder.Build().ValueOrDie();
-
- XlaBuilder false_builder("false");
- Tuple(&false_builder,
- {Parameter(&false_builder, 0, s32_shape, "cond_param")});
- auto false_computation = false_builder.Build().ValueOrDie();
-
- XlaBuilder b(TestName());
- auto pred = ConstantR0<bool>(&b, true);
- auto constant = ConstantR0<int32>(&b, 0);
- auto cond = Conditional(pred, constant, true_computation, constant,
- false_computation);
- auto gte = GetTupleElement(cond, 0);
- ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
- // Result is not dynamic.
- EXPECT_FALSE(ComputeDynamismLiteral(gte, &b).ValueOrDie().Get<bool>({}));
-}
-
-TEST_F(DynamismInferenceTest,
- InferThroughConditionalPredIsConstantFalseBranch) {
- // The result of following conditional is dynamic.
- // pred = false
- // if (pred) {
- // return (1)
- // } else {
- // return (..dynamic_value...)
- // }
- //
-
- auto s32_shape = ShapeUtil::MakeShape(S32, {});
- auto cond_shape = ShapeUtil::MakeTupleShape({s32_shape});
- XlaBuilder true_builder("true");
- Parameter(&true_builder, 0, s32_shape, "cond_param");
- Tuple(&true_builder, {ConstantR0<int32>(&true_builder, 0)});
- auto true_computation = true_builder.Build().ValueOrDie();
-
- XlaBuilder false_builder("false");
- Tuple(&false_builder,
- {Parameter(&false_builder, 0, s32_shape, "cond_param")});
- auto false_computation = false_builder.Build().ValueOrDie();
-
- XlaBuilder b(TestName());
- auto pred = ConstantR0<bool>(&b, false);
- auto constant = ConstantR0<int32>(&b, 0);
- auto cond = Conditional(pred, constant, true_computation, constant,
- false_computation);
- auto gte = GetTupleElement(cond, 0);
- ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
- // Result is dynamic.
- EXPECT_TRUE(ComputeDynamismLiteral(gte, &b).ValueOrDie().Get<bool>({}));
-}
-
class UpperBoundInferenceTest : public ValueInferenceTest {
public:
explicit UpperBoundInferenceTest(se::Platform* platform = nullptr)