[ValueInference] Better support for kCall and kRngBitGenerator

- Support call by looking inside the call instruction
- Support RngBitGenerator by conservatively considering it dynamic -- HloValuator doesn't support evaluating it.

PiperOrigin-RevId: 391838493
Change-Id: I16572fcb380595f1098471a23ce82afff0636552
diff --git a/tensorflow/compiler/xla/client/value_inference.cc b/tensorflow/compiler/xla/client/value_inference.cc
index 59dfeeb..193d191 100644
--- a/tensorflow/compiler/xla/client/value_inference.cc
+++ b/tensorflow/compiler/xla/client/value_inference.cc
@@ -438,6 +438,8 @@
   TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
                       handle_to_instruction(handle));
   TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
+  Shape subshape =
+      ShapeUtil::GetSubshape(Shape(root->shape()), context.shape_index);
   PostorderDFSNode result;
   // By default, the dependencies of current node are its operands.
   for (auto operand_id : root->operand_ids()) {
@@ -452,7 +454,7 @@
     case HloOpcode::kReduceScatter:
     case HloOpcode::kInfeed:
     case HloOpcode::kOutfeed:
-    case HloOpcode::kCall:
+    case HloOpcode::kRngBitGenerator:
     case HloOpcode::kCustomCall:
     case HloOpcode::kWhile:
     case HloOpcode::kSend:
@@ -467,12 +469,7 @@
         return result.AddDependency(caller_operand, type, context)
             .AddVisit([](Literal literal) { return literal; });
       }
-      return PostorderDFSNode().AddVisit([root, context](absl::Span<Literal>) {
-        // The value is dynamic. We return a garbage literal here, which
-        // will be masked out later.
-        return CreateGarbageLiteral(
-            ShapeUtil::GetSubshape(Shape(root->shape()), context.shape_index));
-      });
+      return CreateAllDynamicResult(subshape, type);
     }
     // Subtract and Divide use lower-bound as second operand.
     case HloOpcode::kSubtract:
@@ -486,6 +483,25 @@
           "AnalyzeConstantValueFallback can't handle opcode: %s",
           root->opcode());
     }
+    case HloOpcode::kCall: {
+      auto node = PostorderDFSNode();
+      auto* call_proto = root;
+      if (call_proto->operand_ids_size() != 1) {
+        // Only support single operand forwarding.
+        return CreateAllDynamicResult(subshape, type);
+      }
+      int64_t called_root =
+          handle_to_computation(call_proto->called_computation_ids(0))
+              ->root_id();
+      InferenceContext call_context = context;
+      call_context.caller_operand_handles.push_back(call_proto->operand_ids(0));
+      node.AddDependency(called_root, PostorderDFSNodeType::kConstantValue,
+                         call_context, "callee's root instruction");
+      return node.AddVisit([](Literal operand) -> StatusOr<Literal> {
+        // Forward result of callee's root to caller.
+        return std::move(operand);
+      });
+    }
 
     case HloOpcode::kConditional: {
       auto node = PostorderDFSNode();
@@ -1105,6 +1121,27 @@
             .Evaluate();
       });
     }
+    case HloOpcode::kCall: {
+      auto node = PostorderDFSNode();
+      auto* call_proto = root;
+
+      if (call_proto->operand_ids_size() != 1) {
+        // Only support single operand forwarding.
+        return CreateAllDynamicResult(subshape, type);
+      }
+      int64_t call_root =
+          handle_to_computation(call_proto->called_computation_ids(0))
+              ->root_id();
+      InferenceContext branch_context = context;
+      branch_context.caller_operand_handles.push_back(
+          call_proto->operand_ids(0));
+      node.AddDependency(call_root, PostorderDFSNodeType::kValueIsDynamic,
+                         branch_context, "callee's root instruction");
+      return node.AddVisit([context](Literal operand) -> StatusOr<Literal> {
+        // Forward result of callee's root to caller.
+        return operand;
+      });
+    }
     case HloOpcode::kConditional: {
       auto node = PostorderDFSNode();
       auto* conditional_proto = root;
@@ -1363,7 +1400,6 @@
       break;
     }
 
-    case HloOpcode::kCall:
     case HloOpcode::kRecv:
     case HloOpcode::kRecvDone:
     case HloOpcode::kSend:
@@ -1671,12 +1707,12 @@
 StatusOr<OptionalLiteral> ValueInference::AnalyzeConstant(
     XlaOp op, ValueInferenceMode mode) {
   TF_RETURN_IF_ERROR(builder_->LookUpInstructionByHandle(op.handle()).status());
-
   PostorderDFSVisitor visitor(
       [&](int64_t handle) {
         return builder_->LookUpInstructionByHandle(handle);
       },
       [&](int64_t handle) { return &(builder_->embedded_[handle]); });
+  TF_ASSIGN_OR_RETURN(Shape op_shape, builder_->GetShape(op));
   int64_t handle = op.handle();
   if (ShapeUtil::IsScalar(builder_->GetShape(op).ValueOrDie())) {
     TF_ASSIGN_OR_RETURN(auto result, SimplifyOp(handle));
@@ -1687,33 +1723,47 @@
   }
   switch (mode) {
     case ValueInferenceMode::kLowerBound: {
+      TF_ASSIGN_OR_RETURN(Literal mask,
+                          visitor.PostOrderDFSVisit(
+                              handle, PostorderDFSNodeType::kBoundIsDynamic));
+      if (mask.IsAll(1)) {
+        // Everything is dynamic, no need to do constant inference.
+        return OptionalLiteral(CreateGarbageLiteral(op_shape), std::move(mask));
+      }
       TF_ASSIGN_OR_RETURN(
           Literal value,
           visitor.PostOrderDFSVisit(handle,
                                     PostorderDFSNodeType::kConstantLowerBound));
-      TF_ASSIGN_OR_RETURN(Literal mask,
-                          visitor.PostOrderDFSVisit(
-                              handle, PostorderDFSNodeType::kBoundIsDynamic));
+
       return OptionalLiteral(std::move(value), std::move(mask));
     }
     case ValueInferenceMode::kUpperBound: {
+      TF_ASSIGN_OR_RETURN(Literal mask,
+                          visitor.PostOrderDFSVisit(
+                              handle, PostorderDFSNodeType::kBoundIsDynamic));
+      if (mask.IsAll(1)) {
+        // Everything is dynamic, no need to do constant inference.
+        return OptionalLiteral(CreateGarbageLiteral(op_shape), std::move(mask));
+      }
       TF_ASSIGN_OR_RETURN(
           Literal value,
           visitor.PostOrderDFSVisit(handle,
                                     PostorderDFSNodeType::kConstantUpperBound));
-      TF_ASSIGN_OR_RETURN(Literal mask,
-                          visitor.PostOrderDFSVisit(
-                              handle, PostorderDFSNodeType::kBoundIsDynamic));
 
       return OptionalLiteral(std::move(value), std::move(mask));
     }
     case ValueInferenceMode::kValue: {
-      TF_ASSIGN_OR_RETURN(Literal value,
-                          visitor.PostOrderDFSVisit(
-                              handle, PostorderDFSNodeType::kConstantValue));
       TF_ASSIGN_OR_RETURN(Literal mask,
                           visitor.PostOrderDFSVisit(
                               handle, PostorderDFSNodeType::kValueIsDynamic));
+      if (mask.IsAll(1)) {
+        // Everything is dynamic, no need to do constant inference.
+        return OptionalLiteral(CreateGarbageLiteral(op_shape), std::move(mask));
+      }
+      TF_ASSIGN_OR_RETURN(Literal value,
+                          visitor.PostOrderDFSVisit(
+                              handle, PostorderDFSNodeType::kConstantValue));
+
       return OptionalLiteral(std::move(value), std::move(mask));
     }
   }
diff --git a/tensorflow/compiler/xla/tests/value_inference_test.cc b/tensorflow/compiler/xla/tests/value_inference_test.cc
index 55bd8b0..58f65ce 100644
--- a/tensorflow/compiler/xla/tests/value_inference_test.cc
+++ b/tensorflow/compiler/xla/tests/value_inference_test.cc
@@ -372,6 +372,32 @@
   EXPECT_FALSE(ComputeDynamismLiteral(gte, &b).ValueOrDie().Get<bool>({}));
 }
 
+TEST_F(DynamismInferenceTest, InferThroughCall) {
+  // The result of following call instruction is static.
+  //
+  // Callee:
+  //   p = param
+  //   return p
+  //
+  // Entry:
+  //   c = constant(3)
+  //   return call(c), callee
+  //
+  //
+
+  auto s32_shape = ShapeUtil::MakeShape(S32, {});
+  XlaBuilder call_builder("call");
+  Parameter(&call_builder, 0, s32_shape, "call_param");
+  auto call_computation = call_builder.Build().ValueOrDie();
+
+  XlaBuilder b(TestName());
+  auto constant = ConstantR0<int32>(&b, 3);
+  auto call = Call(&b, call_computation, {constant});
+  ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
+  // Result is static.
+  EXPECT_EQ(ComputeDynamismScalar(call, &b, {}).ValueOrDie(), false);
+}
+
 TEST_F(DynamismInferenceTest, InferThroughConditionalBranchesAreNotSame) {
   // The result of following conditional is dynamic.
   // pred = .. # a dynamic value