[XLA/Bridge] Provide a flag not to resolve compile time constants.
For some models, resolving constant expressions in the bridge leads to
non-linear slowdown.
PiperOrigin-RevId: 280205828
Change-Id: Icaa0394d8045c31544b66123857da7cb1b1701b9
diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc
index 53f9b70..35dee15 100644
--- a/tensorflow/compiler/jit/flags.cc
+++ b/tensorflow/compiler/jit/flags.cc
@@ -161,6 +161,9 @@
Flag("tf_xla_always_defer_compilation",
&ops_flags->tf_xla_always_defer_compilation, ""),
+ Flag("tf_xla_noresolve_compile_time_constants",
+ &ops_flags->tf_xla_noresolve_compile_time_constants,
+ "Do not perform constant folding in XlaCompiler::CompileGraph"),
Flag("tf_introduce_floating_point_jitter_to_tensors",
setter_for_jitter_tensor_names, "",
diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h
index 9307874..baed7ad 100644
--- a/tensorflow/compiler/jit/flags.h
+++ b/tensorflow/compiler/jit/flags.h
@@ -91,6 +91,14 @@
// If true, _XlaCompile always refuses to compile the cluster, which means the
// XLA clusters always run in the TF executor. Defaults to false.
bool tf_xla_always_defer_compilation;
+
+ // If true, sets compile_options.resolve_compile_time_constants to false,
+ // which stops the bridge from using the HloEvaluator for constant resolution
+ // in XlaCompiler::CompileGraph.
+ //
+ // For some models, constant folding during compile graph experiences a
+ // non-linear blow up, which overshadows both compilation and execution.
+ bool tf_xla_noresolve_compile_time_constants;
};
// Flags for the build_xla_ops pass.
diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc
index 0e8bce3..edb19bc 100644
--- a/tensorflow/compiler/jit/kernels/xla_ops.cc
+++ b/tensorflow/compiler/jit/kernels/xla_ops.cc
@@ -326,11 +326,8 @@
}
XlaCompiler::CompileOptions compile_options;
compile_options.is_entry_computation = true;
- // If we resolve constants we never emit them on the device, meaning that if
- // they are needed by a following computation the host has to transfer
- // them. Not resolving constants is expected to be faster than resolving
- // constants.
- compile_options.resolve_compile_time_constants = true;
+ compile_options.resolve_compile_time_constants =
+ !GetXlaOpsCommonFlags().tf_xla_noresolve_compile_time_constants;
// Optimization: where possible, have the computation return a naked array
// rather than a one-element tuple.
compile_options.always_return_tuple = false;