| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #ifndef TENSORFLOW_COMPILER_JIT_FLAGS_H_ |
| #define TENSORFLOW_COMPILER_JIT_FLAGS_H_ |
| |
| #include <vector> |
| |
| #include "tensorflow/core/platform/types.h" |
| #include "tensorflow/core/util/command_line_flags.h" |
| |
| namespace tensorflow { |
| |
| struct XlaAutoJitFlag { |
| // Control compilation of operators into XLA computations on CPU and GPU |
| // devices. 0 = use ConfigProto setting; -1 = off; 1 = on for things very |
| // likely to be improved; 2 = on for everything. |
| // |
| // If all non-CPU ops in the graph being optimized are placed on a single GPU |
| // and there is at least one node placed on that GPU then |
| // `optimization_level_single_gpu` applies. Otherwise |
| // `optimization_level_general` applies. |
| // |
| // Experimental. |
| int32 optimization_level_single_gpu; |
| int32 optimization_level_general; |
| }; |
| |
| // Sets the xla_auto_jit_flag based on the given flag sting. Supported syntax |
| // is: |
| // <number>: sets general and single_gpu setting to the provided number. |
| // single-gpu(<number>): sets the single_gpu setting to the provided number. |
| bool SetXlaAutoJitFlagFromFlagString(const string& value); |
| |
| // Flags associated with the XLA bridge's mark_for_compilation_pass module. |
| struct MarkForCompilationPassFlags { |
| XlaAutoJitFlag xla_auto_jit_flag; |
| |
| // Minimum number of operators in an XLA compilation. Ignored for operators |
| // placed on an XLA device or operators explicitly marked for compilation. |
| int32 tf_xla_min_cluster_size; |
| |
| // Maximum number of operators in an XLA compilation. |
| int32 tf_xla_max_cluster_size; |
| |
| // Dump graphs during XLA compilation. |
| bool tf_xla_clustering_debug; |
| |
| // Enables global JIT compilation for CPU via SessionOptions. |
| bool tf_xla_cpu_global_jit; |
| |
| // "Compiler fuel" for clustering. Only this many ops will be marked as |
| // eligible for clustering. |
| int64 tf_xla_clustering_fuel; |
| |
| // If tf_xla_disable_deadness_safety_checks_for_debugging is set to true then |
| // we do not do deadness related safety checks. This is unsound in general, |
| // but can be used as a debugging aid. |
| bool tf_xla_disable_deadness_safety_checks_for_debugging; |
| |
| // If tf_xla_disable_resource_variable_safety_checks_for_debugging is set to |
| // true then we do not do safety checks to preserve TensorFlow's resource |
| // variable concurrency semantics. This is unsound in general, but can be |
| // used as a debugging aid. |
| bool tf_xla_disable_resource_variable_safety_checks_for_debugging; |
| }; |
| |
// Flags associated with the XLA bridge's xla_device module.
struct XlaDeviceFlags {
  // Switch the CPU device into "on-demand" mode, where instead of
  // autoclustering, ops are compiled one by one just-in-time.
  //
  // Enabling this mode by a legacy flag is a temporary mechanism. When this
  // feature is battle-tested, we will switch this to be a session option.
  bool tf_xla_compile_on_demand;
};
| |
// Flags common to the _Xla* ops and their kernels.
struct XlaOpsCommonFlags {
  // If true, _XlaCompile always refuses to compile the cluster, which means the
  // XLA clusters always run in the TF executor. Defaults to false.
  bool tf_xla_always_defer_compilation;

  // If true, sets compile_options.resolve_compile_time_constants to false,
  // which stops the bridge from using the HloEvaluator for constant resolution
  // in XlaCompiler::CompileGraph.
  //
  // For some models, constant folding during compile graph experiences a
  // non-linear blow up, which overshadows both compilation and execution.
  bool tf_xla_noresolve_compile_time_constants;
};
| |
// Flags for the build_xla_ops pass.
struct BuildXlaOpsPassFlags {
  // Enables lazy compilation for TF/XLA (only when auto-clustering) if true.
  // Defaults to true.
  bool tf_xla_enable_lazy_compilation;

  // If true then insert Print nodes to print out values produced by XLA
  // clusters. Useful for debugging.
  bool tf_xla_print_cluster_outputs;

  // If true, insert CheckNumerics nodes for every floating point typed input to
  // an XLA cluster.
  bool tf_xla_check_cluster_input_numerics;

  // If true, insert CheckNumerics nodes for every floating point typed output
  // from an XLA cluster.
  bool tf_xla_check_cluster_output_numerics;

  // Disables all constant folding. The primary use for this is for testing to
  // guarantee that tests are run on XLA and not on TF's CPU implementation.
  bool tf_xla_disable_constant_folding;
};
| |
| // Flags for the IntroduceFloatingPointJitter pass. |
| struct IntroduceFloatingPointJitterPassFlags { |
| // The amount of jitter to introduce. This amount is added to each element in |
| // the tensors named in `tensor_names. |
| float jitter_amount; |
| |
| // The Tensors to add the jitter to. The tensors are named in the TensorId |
| // format of <node name>:<output idx>. |
| std::vector<string> tensor_names; |
| }; |
| |
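// NOTE(review): the comment below has no declaration attached to it and refers
// to a DumpGraphFlags struct that is not declared in this header; it appears
// to be stale.
//
// Return a pointer to the DumpGraphFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.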
| |
| // Getters for flags structs defined above. The first call to any of these |
| // parses TF_XLA_FLAGS for all of them. Those functions which return a pointer |
| // always return the same pointer. |
| MarkForCompilationPassFlags* GetMarkForCompilationPassFlags(); |
| BuildXlaOpsPassFlags* GetBuildXlaOpsPassFlags(); |
| XlaDeviceFlags* GetXlaDeviceFlags(); |
| const XlaOpsCommonFlags& GetXlaOpsCommonFlags(); |
| |
| const IntroduceFloatingPointJitterPassFlags& |
| GetIntroduceFloatingPointJitterPassFlags(); |
| |
| // Appends the flag definitions associated with |
| // MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`. |
| // |
| // Has the side-effect of parsing TF_XLA_FLAGS if that hasn't happened yet. |
| void AppendMarkForCompilationPassFlags( |
| std::vector<tensorflow::Flag>* flag_list); |
| } // namespace tensorflow |
| |
| #endif // TENSORFLOW_COMPILER_JIT_FLAGS_H_ |