Allow value inference to analyze into embedded conditional computations.

PiperOrigin-RevId: 369016974
Change-Id: I8c37bedb8a2621b3a65c6270a01b585f7b33a2d6
diff --git a/tensorflow/compiler/xla/client/value_inference.cc b/tensorflow/compiler/xla/client/value_inference.cc
index 3271347..7dd2395 100644
--- a/tensorflow/compiler/xla/client/value_inference.cc
+++ b/tensorflow/compiler/xla/client/value_inference.cc
@@ -314,9 +314,8 @@
     case HloOpcode::kAbs:
     case HloOpcode::kDivide:
     case HloOpcode::kGetDimensionSize: {
-      return InvalidArgument(
-          "AnalyzeConstantValueFallback can't handle opcode: %s",
-          root->opcode());
+      return InvalidArgument("AnalyzeConstantValue can't handle opcode: %s",
+                             root->opcode());
     }
     case HloOpcode::kGetTupleElement: {
       int64 operand_handle = root->operand_ids(0);
@@ -332,52 +331,6 @@
         });
       }
 
-      if (operand_opcode == HloOpcode::kConditional) {
-        int64 index = root->tuple_index();
-        auto node = PostorderDFSNode();
-        auto* conditional_proto = operand_proto;
-        // Add dependencies to analyze the predicate of the conditional.
-        node.AddDependency(conditional_proto->operand_ids(0),
-                           PostorderDFSNodeType::kConstantValue)
-            .AddDependency(conditional_proto->operand_ids(0),
-                           PostorderDFSNodeType::kValueIsDynamic);
-        const int64 branch_size =
-            conditional_proto->called_computation_ids_size();
-        for (int64 i = 0; i < branch_size; ++i) {
-          int64 branch_root = handle_to_computation(
-                                  conditional_proto->called_computation_ids(i))
-                                  ->root_id();
-          int64 branch_root_operand =
-              handle_to_instruction(branch_root)->operand_ids(index);
-          node.AddDependency(branch_root_operand,
-                             PostorderDFSNodeType::kConstantValue);
-        }
-        return node.AddVisit(
-            [](absl::Span<Literal> operands) -> StatusOr<Literal> {
-              int64 pred_is_dynamic = operands[1].Get<bool>({});
-              if (pred_is_dynamic) {
-                // If predicate is dynamic, return the value of the first branch
-                // -- if all branches return the same value, this is the value
-                // that we want. If not, the value will be masked anyway so the
-                // value inside doesn't matter.
-                return std::move(operands[2]);
-              } else {
-                // If predicate is static, return the value of given branch.
-                int64 branch_index = 0;
-                if (operands[0].shape().element_type() == PRED) {
-                  if (operands[0].Get<bool>({})) {
-                    branch_index = 0;
-                  } else {
-                    branch_index = 1;
-                  }
-                } else {
-                  branch_index = operands[0].GetIntegralAsS64({}).value();
-                }
-                const int64 branch_dynamism_index = 2 + branch_index;
-                return std::move(operands[branch_dynamism_index]);
-              }
-            });
-      }
       return result.AddVisit([root](absl::Span<Literal> operands) {
         return HloProtoEvaluator(*root)
             .WithOperands(operands)
@@ -624,7 +577,6 @@
   const HloInstructionProto* root = handle_to_instruction(handle);
   // Invariant check.
   TF_RET_CHECK(root);
-  VLOG(1) << "Analyzing IsDynamic on " << root->DebugString();
   TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
   PostorderDFSNode result;
   for (auto operand_id : root->operand_ids()) {
@@ -722,88 +674,11 @@
       TF_ASSIGN_OR_RETURN(HloOpcode operand_opcode,
                           StringToHloOpcode(operand_proto->opcode()));
       if (operand_opcode == HloOpcode::kParameter) {
-        return PostorderDFSNode().AddVisit([root]() -> StatusOr<Literal> {
-          // As an optimization, don't materialize the whole parameter if it's
-          // followed by a GTE.
+        return PostorderDFSNode().AddVisit([root]() -> StatusOr<Literal> {
+          // Don't materialize the whole parameter if it's followed by a GTE.
           return CreatePredLiteral(true, Shape(root->shape()));
         });
       }
-
-      if (operand_opcode == HloOpcode::kConditional) {
-        int64 index = root->tuple_index();
-        auto* conditional_proto = operand_proto;
-        auto node = PostorderDFSNode();
-        // Add dependencies to analyze the predicate of the conditional.
-        node.AddDependency(conditional_proto->operand_ids(0),
-                           PostorderDFSNodeType::kConstantValue)
-            .AddDependency(conditional_proto->operand_ids(0),
-                           PostorderDFSNodeType::kValueIsDynamic);
-        const int64 branch_size =
-            conditional_proto->called_computation_ids_size();
-        for (int64 i = 0; i < branch_size; ++i) {
-          int64 branch_root = handle_to_computation(
-                                  conditional_proto->called_computation_ids(i))
-                                  ->root_id();
-          int64 branch_root_operand =
-              handle_to_instruction(branch_root)->operand_ids(index);
-          node.AddDependency(branch_root_operand,
-                             PostorderDFSNodeType::kConstantValue)
-              .AddDependency(branch_root_operand,
-                             PostorderDFSNodeType::kValueIsDynamic);
-        }
-        // Predicate uses 2 dependencies:
-        // 0: Predicate value.
-        // 1: Predicate is dynamic.
-        // Each branch i has 2 dependenices:
-        // 2*i: Branch result value
-        // 2*i + 1: Branch value is dynamic.
-        return node.AddVisit([root, branch_size](absl::Span<Literal> operands)
-                                 -> StatusOr<Literal> {
-          int64 pred_is_dynamic = operands[1].Get<bool>({});
-          auto result = CreatePredLiteral(true, Shape(root->shape()));
-          if (pred_is_dynamic) {
-            // If predicate is dynamic, the result is only static if all
-            // branches are static and return the same value.
-            auto result = CreatePredLiteral(true, Shape(root->shape()));
-
-            result.MutableEachCell<bool>(
-                [&](absl::Span<const int64> indices, bool value) {
-                  string branch_value = operands[2].GetAsString(indices, {});
-                  for (int64 i = 0; i < branch_size; ++i) {
-                    const int64 branch_value_index = 2 + 2 * i;
-                    const int64 branch_dynamism_index = 2 + 2 * i + 1;
-                    auto branch_is_dynamic =
-                        operands[branch_dynamism_index].Get<bool>(indices);
-                    if (branch_is_dynamic) {
-                      return true;
-                    }
-                    if (branch_value !=
-                        operands[branch_value_index].GetAsString(indices, {})) {
-                      return true;
-                    }
-                  }
-                  // Value of the branch is static.
-                  return false;
-                });
-            return result;
-          } else {
-            // If predicate is static, return true if given branch result
-            // value is dynamic.
-            int64 branch_index = 0;
-            if (operands[0].shape().element_type() == PRED) {
-              if (operands[0].Get<bool>({})) {
-                branch_index = 0;
-              } else {
-                branch_index = 1;
-              }
-            } else {
-              branch_index = operands[0].GetIntegralAsS64({}).value();
-            }
-            const int64 branch_dynamism_index = 2 + 2 * branch_index + 1;
-            return std::move(operands[branch_dynamism_index]);
-          }
-        });
-      }
       return result.AddVisit([root](absl::Span<Literal> operands) {
         return HloProtoEvaluator(*root)
             .WithOperands(operands)
@@ -923,7 +798,7 @@
       break;
     }
     default:
-      return Unimplemented("Can't infer dynamism through %s: %s",
+      return Unimplemented("Can't infer dynamism through %s: %s",
                            root->opcode(), root->DebugString());
   }
 }
@@ -977,23 +852,19 @@
     PostorderDFSNode node;
     switch (item.type) {
       case PostorderDFSNodeType::kConstantValue: {
-        VLOG(1) << "constant value";
-        TF_ASSIGN_OR_RETURN(node, AnalyzeConstant(item.handle));
+        TF_ASSIGN_OR_RETURN(node, AnalyzeConstant(item.handle));
         break;
       }
       case PostorderDFSNodeType::kConstantLowerBound: {
-        VLOG(1) << "constant lower bound";
-        TF_ASSIGN_OR_RETURN(node, AnalyzeLowerBound(item.handle));
+        TF_ASSIGN_OR_RETURN(node, AnalyzeLowerBound(item.handle));
         break;
       }
       case PostorderDFSNodeType::kConstantUpperBound: {
-        VLOG(1) << "constant upper bound";
         TF_ASSIGN_OR_RETURN(node, AnalyzeUpperBound(item.handle));
         break;
       }
       case PostorderDFSNodeType::kBoundIsDynamic:
       case PostorderDFSNodeType::kValueIsDynamic: {
-        VLOG(1) << "value is dynamic";
         TF_ASSIGN_OR_RETURN(node, AnalyzeIsDynamic(item.handle, item.type));
         break;
       }
@@ -1016,9 +887,8 @@
         return builder_->LookUpInstructionByHandle(handle).ValueOrDie();
       },
       [&](int64 handle) { return &(builder_->embedded_[handle]); });
-  auto result = visitor.PostOrderDFSVisit(
-      op.handle(), PostorderDFSNodeType::kValueIsDynamic);
-  return result;
+  return visitor.PostOrderDFSVisit(op.handle(),
+                                   PostorderDFSNodeType::kValueIsDynamic);
 }
 
 StatusOr<OptionalLiteral> ValueInference::AnalyzeConstant(
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 5f5aba6..9d249c1 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -3767,15 +3767,6 @@
     }
 
     int64 computation_id = imported_computation.id();
-    for (int64 i = 0; i < imported_computation.instructions_size(); ++i) {
-      ImportedInstruction imported_instruction;
-      imported_instruction.computation_id = computation_id;
-      imported_instruction.instruction_index = i;
-      handle_to_imported_index_.insert(
-          {imported_computation.instructions(i).id(),
-           ImportedInstruction{.computation_id = computation_id,
-                               .instruction_index = i}});
-    }
     embedded_.insert({computation_id, std::move(imported_computation)});
   }
 }
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 6926fbc..df8ba65 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -1003,15 +1003,6 @@
   // instruction is held.
   absl::flat_hash_map<int64, int64> handle_to_index_;
 
-  // Track imported instructions by their computation id and the position in
-  // their computation's instruction list.
-  struct ImportedInstruction {
-    int64 computation_id;
-    int64 instruction_index;
-  };
-
-  absl::flat_hash_map<int64, ImportedInstruction> handle_to_imported_index_;
-
   // The embedded computations used by this computation. Each computation was
   // the entry computation of some XlaComputation, the key is the unique id of
   // that XlaComputation.
@@ -1454,14 +1445,6 @@
       int64 handle) const {
     auto it = handle_to_index_.find(handle);
     if (it == handle_to_index_.end()) {
-      // Try look for the instruction in the imported instructions.
-      auto imported_it = handle_to_imported_index_.find(handle);
-      if (imported_it != handle_to_imported_index_.end()) {
-        ImportedInstruction imported = imported_it->second;
-        return const_cast<InstructionType>(
-            &embedded_.at(imported.computation_id)
-                 .instructions(imported.instruction_index));
-      }
       return InvalidArgument("No XlaOp with handle %d", handle);
     }
     return const_cast<InstructionType>(&instructions_.at(it->second));
diff --git a/tensorflow/compiler/xla/tests/dynamism_inference_test.cc b/tensorflow/compiler/xla/tests/dynamism_inference_test.cc
index 69c8083..a5019fa 100644
--- a/tensorflow/compiler/xla/tests/dynamism_inference_test.cc
+++ b/tensorflow/compiler/xla/tests/dynamism_inference_test.cc
@@ -298,139 +298,6 @@
   EXPECT_TRUE(ComputeDynamismLiteral(pad, &b).ValueOrDie().Get<bool>({2}));
 }
 
-TEST_F(DynamismInferenceTest, InferThroughConditionalBranchesAreSame) {
-  // The result of following conditional is static.
-  // pred = .. # a dynamic value
-  // if (pred) {
-  //  return (1) # both branches return the same value
-  // } else {
-  //  return (1)
-  // }
-  //
-
-  auto s32_shape = ShapeUtil::MakeShape(S32, {});
-  auto cond_shape = ShapeUtil::MakeTupleShape({s32_shape});
-  XlaBuilder true_builder("true");
-  Parameter(&true_builder, 0, s32_shape, "cond_param");
-  Tuple(&true_builder, {ConstantR0<int32>(&true_builder, 1)});
-  auto true_computation = true_builder.Build().ValueOrDie();
-
-  XlaBuilder false_builder("false");
-  Parameter(&false_builder, 0, s32_shape, "cond_param");
-  Tuple(&false_builder, {ConstantR0<int32>(&false_builder, 1)});
-  auto false_computation = false_builder.Build().ValueOrDie();
-
-  XlaBuilder b(TestName());
-  auto parameter = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {}), "p0");
-  auto constant = ConstantR0<int32>(&b, 0);
-  auto cond = Conditional(parameter, constant, true_computation, constant,
-                          false_computation);
-  auto gte = GetTupleElement(cond, 0);
-  ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
-  // Result is not dynamic.
-  EXPECT_FALSE(ComputeDynamismLiteral(gte, &b).ValueOrDie().Get<bool>({}));
-}
-
-TEST_F(DynamismInferenceTest, InferThroughConditionalBranchesAreNotSame) {
-  // The result of following conditional is dynamic.
-  // pred = .. # a dynamic value
-  // if (pred) {
-  //  return (1) # These two branches return different values.
-  // } else {
-  //  return (2)
-  // }
-  //
-
-  auto s32_shape = ShapeUtil::MakeShape(S32, {});
-  auto cond_shape = ShapeUtil::MakeTupleShape({s32_shape});
-  XlaBuilder true_builder("true");
-  Parameter(&true_builder, 0, s32_shape, "cond_param");
-  Tuple(&true_builder, {ConstantR0<int32>(&true_builder, 1)});
-  auto true_computation = true_builder.Build().ValueOrDie();
-
-  XlaBuilder false_builder("false");
-  Parameter(&false_builder, 0, s32_shape, "cond_param");
-  Tuple(&false_builder, {ConstantR0<int32>(&false_builder, 2)});
-  auto false_computation = false_builder.Build().ValueOrDie();
-
-  XlaBuilder b(TestName());
-  auto parameter = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {}), "p0");
-  auto constant = ConstantR0<int32>(&b, 0);
-  auto cond = Conditional(parameter, constant, true_computation, constant,
-                          false_computation);
-  auto gte = GetTupleElement(cond, 0);
-  ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
-  // Result is dynamic.
-  EXPECT_TRUE(ComputeDynamismLiteral(gte, &b).ValueOrDie().Get<bool>({}));
-}
-
-TEST_F(DynamismInferenceTest, InferThroughConditionalPredIsConstantTrueBranch) {
-  // The result of following conditional is static.
-  // pred = true
-  // if (pred) {
-  //  return (1)
-  // } else {
-  //  return (..dynamic_value...)
-  // }
-  //
-
-  auto s32_shape = ShapeUtil::MakeShape(S32, {});
-  auto cond_shape = ShapeUtil::MakeTupleShape({s32_shape});
-  XlaBuilder true_builder("true");
-  Parameter(&true_builder, 0, s32_shape, "cond_param");
-  Tuple(&true_builder, {ConstantR0<int32>(&true_builder, 0)});
-  auto true_computation = true_builder.Build().ValueOrDie();
-
-  XlaBuilder false_builder("false");
-  Tuple(&false_builder,
-        {Parameter(&false_builder, 0, s32_shape, "cond_param")});
-  auto false_computation = false_builder.Build().ValueOrDie();
-
-  XlaBuilder b(TestName());
-  auto pred = ConstantR0<bool>(&b, true);
-  auto constant = ConstantR0<int32>(&b, 0);
-  auto cond = Conditional(pred, constant, true_computation, constant,
-                          false_computation);
-  auto gte = GetTupleElement(cond, 0);
-  ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
-  // Result is not dynamic.
-  EXPECT_FALSE(ComputeDynamismLiteral(gte, &b).ValueOrDie().Get<bool>({}));
-}
-
-TEST_F(DynamismInferenceTest,
-       InferThroughConditionalPredIsConstantFalseBranch) {
-  // The result of following conditional is dynamic.
-  // pred = false
-  // if (pred) {
-  //  return (1)
-  // } else {
-  //  return (..dynamic_value...)
-  // }
-  //
-
-  auto s32_shape = ShapeUtil::MakeShape(S32, {});
-  auto cond_shape = ShapeUtil::MakeTupleShape({s32_shape});
-  XlaBuilder true_builder("true");
-  Parameter(&true_builder, 0, s32_shape, "cond_param");
-  Tuple(&true_builder, {ConstantR0<int32>(&true_builder, 0)});
-  auto true_computation = true_builder.Build().ValueOrDie();
-
-  XlaBuilder false_builder("false");
-  Tuple(&false_builder,
-        {Parameter(&false_builder, 0, s32_shape, "cond_param")});
-  auto false_computation = false_builder.Build().ValueOrDie();
-
-  XlaBuilder b(TestName());
-  auto pred = ConstantR0<bool>(&b, false);
-  auto constant = ConstantR0<int32>(&b, 0);
-  auto cond = Conditional(pred, constant, true_computation, constant,
-                          false_computation);
-  auto gte = GetTupleElement(cond, 0);
-  ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
-  // Result is dynamic.
-  EXPECT_TRUE(ComputeDynamismLiteral(gte, &b).ValueOrDie().Get<bool>({}));
-}
-
 class UpperBoundInferenceTest : public ValueInferenceTest {
  public:
   explicit UpperBoundInferenceTest(se::Platform* platform = nullptr)