Always propagate constants for while loops, including V1.
PiperOrigin-RevId: 368897854
Change-Id: I16609b36680c2018f1d91b1741afd90a5feb6a6a
diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc
index 14f03b1..dbbf6cd 100644
--- a/tensorflow/compiler/tf2xla/kernels/while_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc
@@ -285,10 +285,6 @@
} else {
has_token_input_output_ = !token_input_nodes_.empty();
}
- if (ctx->HasAttr(kPropagateCompileTimeConsts)) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr(kPropagateCompileTimeConsts,
- &propagate_compile_time_consts_));
- }
if (!ctx->GetAttr(kXlaOriginalOutsideCompilationNodeName,
&original_node_name_)
.ok())
@@ -325,12 +321,10 @@
// with the const args.
std::vector<bool> compile_time_const_arg_indices(ctx->num_inputs());
int num_compile_time_const_args = 0;
- if (propagate_compile_time_consts_) {
- OP_REQUIRES_OK(ctx, ConvertLoopInvariantsToConst(
- ctx, body_name_attr_, cond_name_attr_, &arguments,
- &compile_time_const_arg_indices,
- &num_compile_time_const_args, compiler->client()));
- }
+ OP_REQUIRES_OK(ctx, ConvertLoopInvariantsToConst(
+ ctx, body_name_attr_, cond_name_attr_, &arguments,
+ &compile_time_const_arg_indices,
+ &num_compile_time_const_args, compiler->client()));
VLOG(1) << "Compiling body";
diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.h b/tensorflow/compiler/tf2xla/kernels/while_op.h
index 0e259b3..ec8875f 100644
--- a/tensorflow/compiler/tf2xla/kernels/while_op.h
+++ b/tensorflow/compiler/tf2xla/kernels/while_op.h
@@ -59,10 +59,6 @@
bool has_token_input_output_;
std::vector<string> token_input_nodes_;
string original_node_name_;
- // Whether to propagate compile time consts into the loop body.
- // This is not supported by default now since it may cause HBM memory
- // overheads.
- bool propagate_compile_time_consts_ = false;
TF_DISALLOW_COPY_AND_ASSIGN(XlaWhileOp);
};