[ValueInference] Better support for kCall and kRngBitGenerator
- Support call by looking inside the call instruction
- Support RngBitGenerator by conservatively considering it dynamic -- HloValuator doesn't support evaluating it.
PiperOrigin-RevId: 391838493
Change-Id: I16572fcb380595f1098471a23ce82afff0636552
diff --git a/tensorflow/compiler/xla/client/value_inference.cc b/tensorflow/compiler/xla/client/value_inference.cc
index 59dfeeb..193d191 100644
--- a/tensorflow/compiler/xla/client/value_inference.cc
+++ b/tensorflow/compiler/xla/client/value_inference.cc
@@ -438,6 +438,8 @@
TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
handle_to_instruction(handle));
TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
+ Shape subshape =
+ ShapeUtil::GetSubshape(Shape(root->shape()), context.shape_index);
PostorderDFSNode result;
// By default, the dependencies of current node are its operands.
for (auto operand_id : root->operand_ids()) {
@@ -452,7 +454,7 @@
case HloOpcode::kReduceScatter:
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
- case HloOpcode::kCall:
+ case HloOpcode::kRngBitGenerator:
case HloOpcode::kCustomCall:
case HloOpcode::kWhile:
case HloOpcode::kSend:
@@ -467,12 +469,7 @@
return result.AddDependency(caller_operand, type, context)
.AddVisit([](Literal literal) { return literal; });
}
- return PostorderDFSNode().AddVisit([root, context](absl::Span<Literal>) {
- // The value is dynamic. We return a garbage literal here, which
- // will be masked out later.
- return CreateGarbageLiteral(
- ShapeUtil::GetSubshape(Shape(root->shape()), context.shape_index));
- });
+ return CreateAllDynamicResult(subshape, type);
}
// Subtract and Divide use lower-bound as second operand.
case HloOpcode::kSubtract:
@@ -486,6 +483,25 @@
"AnalyzeConstantValueFallback can't handle opcode: %s",
root->opcode());
}
+ case HloOpcode::kCall: {
+ auto node = PostorderDFSNode();
+ auto* call_proto = root;
+ if (call_proto->operand_ids_size() != 1) {
+ // Only support single operand forwarding.
+ return CreateAllDynamicResult(subshape, type);
+ }
+ int64_t called_root =
+ handle_to_computation(call_proto->called_computation_ids(0))
+ ->root_id();
+ InferenceContext call_context = context;
+ call_context.caller_operand_handles.push_back(call_proto->operand_ids(0));
+ node.AddDependency(called_root, PostorderDFSNodeType::kConstantValue,
+ call_context, "callee's root instruction");
+ return node.AddVisit([](Literal operand) -> StatusOr<Literal> {
+ // Forward result of callee's root to caller.
+ return std::move(operand);
+ });
+ }
case HloOpcode::kConditional: {
auto node = PostorderDFSNode();
@@ -1105,6 +1121,27 @@
.Evaluate();
});
}
+ case HloOpcode::kCall: {
+ auto node = PostorderDFSNode();
+ auto* call_proto = root;
+
+ if (call_proto->operand_ids_size() != 1) {
+ // Only support single operand forwarding.
+ return CreateAllDynamicResult(subshape, type);
+ }
+ int64_t call_root =
+ handle_to_computation(call_proto->called_computation_ids(0))
+ ->root_id();
+ InferenceContext branch_context = context;
+ branch_context.caller_operand_handles.push_back(
+ call_proto->operand_ids(0));
+ node.AddDependency(call_root, PostorderDFSNodeType::kValueIsDynamic,
+ branch_context, "callee's root instruction");
+ return node.AddVisit([context](Literal operand) -> StatusOr<Literal> {
+ // Forward result of callee's root to caller.
+ return operand;
+ });
+ }
case HloOpcode::kConditional: {
auto node = PostorderDFSNode();
auto* conditional_proto = root;
@@ -1363,7 +1400,6 @@
break;
}
- case HloOpcode::kCall:
case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
case HloOpcode::kSend:
@@ -1671,12 +1707,12 @@
StatusOr<OptionalLiteral> ValueInference::AnalyzeConstant(
XlaOp op, ValueInferenceMode mode) {
TF_RETURN_IF_ERROR(builder_->LookUpInstructionByHandle(op.handle()).status());
-
PostorderDFSVisitor visitor(
[&](int64_t handle) {
return builder_->LookUpInstructionByHandle(handle);
},
[&](int64_t handle) { return &(builder_->embedded_[handle]); });
+ TF_ASSIGN_OR_RETURN(Shape op_shape, builder_->GetShape(op));
int64_t handle = op.handle();
if (ShapeUtil::IsScalar(builder_->GetShape(op).ValueOrDie())) {
TF_ASSIGN_OR_RETURN(auto result, SimplifyOp(handle));
@@ -1687,33 +1723,47 @@
}
switch (mode) {
case ValueInferenceMode::kLowerBound: {
+ TF_ASSIGN_OR_RETURN(Literal mask,
+ visitor.PostOrderDFSVisit(
+ handle, PostorderDFSNodeType::kBoundIsDynamic));
+ if (mask.IsAll(1)) {
+ // Everything is dynamic, no need to do constant inference.
+ return OptionalLiteral(CreateGarbageLiteral(op_shape), std::move(mask));
+ }
TF_ASSIGN_OR_RETURN(
Literal value,
visitor.PostOrderDFSVisit(handle,
PostorderDFSNodeType::kConstantLowerBound));
- TF_ASSIGN_OR_RETURN(Literal mask,
- visitor.PostOrderDFSVisit(
- handle, PostorderDFSNodeType::kBoundIsDynamic));
+
return OptionalLiteral(std::move(value), std::move(mask));
}
case ValueInferenceMode::kUpperBound: {
+ TF_ASSIGN_OR_RETURN(Literal mask,
+ visitor.PostOrderDFSVisit(
+ handle, PostorderDFSNodeType::kBoundIsDynamic));
+ if (mask.IsAll(1)) {
+ // Everything is dynamic, no need to do constant inference.
+ return OptionalLiteral(CreateGarbageLiteral(op_shape), std::move(mask));
+ }
TF_ASSIGN_OR_RETURN(
Literal value,
visitor.PostOrderDFSVisit(handle,
PostorderDFSNodeType::kConstantUpperBound));
- TF_ASSIGN_OR_RETURN(Literal mask,
- visitor.PostOrderDFSVisit(
- handle, PostorderDFSNodeType::kBoundIsDynamic));
return OptionalLiteral(std::move(value), std::move(mask));
}
case ValueInferenceMode::kValue: {
- TF_ASSIGN_OR_RETURN(Literal value,
- visitor.PostOrderDFSVisit(
- handle, PostorderDFSNodeType::kConstantValue));
TF_ASSIGN_OR_RETURN(Literal mask,
visitor.PostOrderDFSVisit(
handle, PostorderDFSNodeType::kValueIsDynamic));
+ if (mask.IsAll(1)) {
+ // Everything is dynamic, no need to do constant inference.
+ return OptionalLiteral(CreateGarbageLiteral(op_shape), std::move(mask));
+ }
+ TF_ASSIGN_OR_RETURN(Literal value,
+ visitor.PostOrderDFSVisit(
+ handle, PostorderDFSNodeType::kConstantValue));
+
return OptionalLiteral(std::move(value), std::move(mask));
}
}
diff --git a/tensorflow/compiler/xla/tests/value_inference_test.cc b/tensorflow/compiler/xla/tests/value_inference_test.cc
index 55bd8b0..58f65ce 100644
--- a/tensorflow/compiler/xla/tests/value_inference_test.cc
+++ b/tensorflow/compiler/xla/tests/value_inference_test.cc
@@ -372,6 +372,32 @@
EXPECT_FALSE(ComputeDynamismLiteral(gte, &b).ValueOrDie().Get<bool>({}));
}
+TEST_F(DynamismInferenceTest, InferThroughCall) {
+ // The result of following call instruction is static.
+ //
+ // Callee:
+ // p = param
+ // return p
+ //
+ // Entry:
+ // c = constant(3)
+ // return call(c), callee
+ //
+ //
+
+ auto s32_shape = ShapeUtil::MakeShape(S32, {});
+ XlaBuilder call_builder("call");
+ Parameter(&call_builder, 0, s32_shape, "call_param");
+ auto call_computation = call_builder.Build().ValueOrDie();
+
+ XlaBuilder b(TestName());
+ auto constant = ConstantR0<int32>(&b, 3);
+ auto call = Call(&b, call_computation, {constant});
+ ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
+ // Result is static.
+ EXPECT_EQ(ComputeDynamismScalar(call, &b, {}).ValueOrDie(), false);
+}
+
TEST_F(DynamismInferenceTest, InferThroughConditionalBranchesAreNotSame) {
// The result of following conditional is dynamic.
// pred = .. # a dynamic value