/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto3";
package xla;
import "tensorflow/compiler/xla/service/hlo.proto";
import "tensorflow/compiler/xla/xla_data.proto";
// Debugging options for XLA. These options may change at any time - there are
// no guarantees about backward or forward compatibility for these fields.
message DebugOptions {
// Show addresses of HLO ops in graph dump.
bool xla_hlo_graph_addresses = 2;
// Instrument the computation to collect per-HLO cycle counts.
bool xla_hlo_profile = 9;
// List of HLO passes to disable/enable. These names must exactly match the
// pass names as specified by the HloPassInterface::name() method.
//
// At least one of xla_disable_hlo_passes and xla_enable_hlo_passes_only must
// be empty.
repeated string xla_disable_hlo_passes = 30;
repeated string xla_enable_hlo_passes_only = 124;
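//
// A sketch of setting the two pass-list fields above via the XLA_FLAGS
// environment variable (the pass names and script name are illustrative;
// real names must match HloPassInterface::name() exactly):
//
//   XLA_FLAGS="--xla_disable_hlo_passes=algsimp,fusion" python my_program.py
//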
// Disables all HLO passes. Note that some passes are necessary for
// correctness and the invariants that must be satisfied by "fully optimized"
// HLO are different for different devices and may change over time. The only
// "guarantee", such as it is, is that if you compile XLA and dump the
// optimized HLO for some graph, you should be able to run it again on the
// same device with the same build of XLA.
bool xla_disable_all_hlo_passes = 104;
// Numerical optimization level for the XLA compiler backend; the specific
// interpretation of this value is left to the backends.
int32 xla_backend_optimization_level = 31;
// Embed the compiler IR as a string in the executable.
bool xla_embed_ir_in_executable = 33;
// Eliminate implicit broadcasts when lowering user computations to HLO
// instructions; use explicit broadcasts instead.
bool xla_eliminate_hlo_implicit_broadcast = 35;
// When generating calls to Eigen in the CPU backend, use multi-threaded Eigen
// mode.
bool xla_cpu_multi_thread_eigen = 60;
// Path to directory with cuda/ptx tools and libraries.
string xla_gpu_cuda_data_dir = 61;
// Enable flush-to-zero semantics in the GPU backend.
bool xla_gpu_ftz = 62;
// Disable multi-streaming in the GPU backend.
bool xla_gpu_disable_multi_streaming = 63;
// Debugging feature: if enabled, the GPU backend will assign HLO operators to
// randomly chosen streams. This is intended to trigger concurrency bugs.
bool xla_gpu_use_random_streams = 134;
// If true, in LLVM-based backends, emit !alias.scope metadata in
// generated IR.
bool xla_llvm_enable_alias_scope_metadata = 70;
// If true, in LLVM-based backends, emit !noalias metadata in the
// generated IR.
bool xla_llvm_enable_noalias_metadata = 71;
// If true, in LLVM-based backends, emit !invariant.load metadata in
// the generated IR.
bool xla_llvm_enable_invariant_load_metadata = 72;
// If true, a set of expensive LLVM optimization passes will not be run.
bool xla_llvm_disable_expensive_passes = 73;
reserved 80; // Was hlo_reduce_precision_options
// This is used by ClientLibraryTestBase::ComputeAndCompare*. If true, the
// computation will run n! times with all permutations of layouts for the
// output shape in rank n. For example, with a 3D shape, all permutations of
// the set {0, 1, 2} are tried.
bool xla_test_all_output_layouts = 90;
// This is used by ClientLibraryTestBase::ComputeAndCompare*. If true, the
// computation will run for all permutations of layouts of all input
// arguments. For example, with 2 input arguments in 2D and 4D shapes, the
// computation will run 2! * 4! times.
bool xla_test_all_input_layouts = 91;
// Assign colors based on sharding information when generating the Graphviz
// HLO graph.
bool xla_hlo_graph_sharding_color = 92;
reserved 93; // Was xla_hlo_tfgraph_device_scopes
reserved 94; // Was xla_gpu_use_cudnn_batchnorm
// Generate calls to MKL-DNN in the CPU backend.
bool xla_cpu_use_mkl_dnn = 97;
// Maximum kernel unroll factor for the GPU backend.
int32 xla_gpu_max_kernel_unroll_factor = 98;
// When true, "unsafe" mathematical optimizations are enabled. These
// transformations include but are not limited to:
//
// - Reducing the precision of operations (e.g. using an approximate sin
// function, or transforming x/y into x * (1/y)).
// - Assuming that operations never produce or consume NaN or +/- Inf (this
// behavior can be adjusted using xla_cpu_fast_math_allow_{nans|infs}).
// - Assuming that +0 and -0 are indistinguishable.
bool xla_cpu_enable_fast_math = 99;
// When xla_cpu_enable_fast_math is true then this controls whether we allow
// operations to produce NaNs. Ignored when xla_cpu_enable_fast_math is
// false.
bool xla_cpu_fast_math_honor_nans = 120;
// When xla_cpu_enable_fast_math is true then this controls whether we allow
// operations to produce infinities. Ignored when xla_cpu_enable_fast_math is
// false.
bool xla_cpu_fast_math_honor_infs = 121;
// When xla_cpu_enable_fast_math is true then this controls whether we forbid
// using the reciprocal of an argument instead of division. Ignored when
// xla_cpu_enable_fast_math is false.
bool xla_cpu_fast_math_honor_division = 126;
// When xla_cpu_enable_fast_math is true then this controls whether we forbid
// approximate calculations for functions. Ignored when
// xla_cpu_enable_fast_math is false.
bool xla_cpu_fast_math_honor_functions = 129;
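//
// A sketch combining the fast-math fields above via XLA_FLAGS (flag names
// mirror the field names; the values and script name are hypothetical):
//
//   XLA_FLAGS="--xla_cpu_enable_fast_math=true \
//              --xla_cpu_fast_math_honor_nans=true \
//              --xla_cpu_fast_math_honor_division=false" python my_program.py
//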
// When false we lower the Minimum and Maximum hlos in the CPU backend such
// that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NaN. In other words, if this
// flag is false we always propagate NaNs through Min and Max.
//
// Note, this does not correspond to the exact same behavior as the gpu flag
// below!
bool xla_cpu_enable_fast_min_max = 140;
// When true we lower the Minimum and Maximum hlos in the GPU backend such
// that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NotNaN. In other words, if this
// flag is true we don't propagate NaNs through Min and Max.
//
// Note, this does not correspond to the exact same behavior as the cpu flag
// above!
bool xla_gpu_enable_fast_min_max = 100;
// Allows XLA to increase the output precision of floating-point operations.
bool xla_allow_excess_precision = 122;
// Crashes the program when any kind of verification fails, instead of just
// logging the failures. One example is cross checking of convolution results
// among different algorithms.
bool xla_gpu_crash_on_verification_failures = 101;
// 0: Disable gemm and convolution autotuning.
// 1: Enable autotuning, but disable correctness checking.
// 2: Also set output buffers to random numbers during autotuning.
// 3: Also reset output buffers to random numbers after autotuning each
// algorithm.
// 4+: Also check for correct outputs and for out-of-bounds reads/writes.
//
// Default: 4.
int32 xla_gpu_autotune_level = 123;
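//
// For example, a hypothetical invocation that keeps gemm/convolution
// autotuning enabled but skips the extra correctness checks:
//
//   XLA_FLAGS="--xla_gpu_autotune_level=1" python my_program.py
//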
// Force the host platform to pretend that there are this many host
// "devices". All these devices are backed by the same threadpool. Defaults
// to 1.
//
// Setting this to anything other than 1 can increase overhead from context
// switching but we let the user override this behavior to help run tests on
// the host that run models in parallel across multiple devices.
int32 xla_force_host_platform_device_count = 102;
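//
// For example, to simulate an 8-device platform on a single host when testing
// multi-device code without accelerators (script name is a placeholder):
//
//   XLA_FLAGS="--xla_force_host_platform_device_count=8" python my_test.py
//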
// If set to true, XLA:GPU invokes `ptxas` with -O0 (default is -O3).
bool xla_gpu_disable_gpuasm_optimizations = 103;
// Enable fast math with Eigen in the HLO evaluator.
bool xla_hlo_evaluator_use_fast_path = 106;
// Temporary option to allow support for both the R1 and the scalar index
// versions of DynamicSlice and DynamicUpdateSlice. Only used for testing.
bool xla_allow_scalar_index_dynamic_ops = 107;
enum StepMarkerLocation {
// Generate a step marker at the program entry. This handles the case where
// each step is done by one or multiple program execution(s). Only the first
// program will be tagged for generating a step marker at the program entry.
// This is the default.
STEP_MARK_AT_ENTRY = 0;
// Generate a step marker at each iteration of the top level while loop,
// which is assumed to be a training loop.
STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP = 1;
// Generate a step marker at each iteration of the second level while loops,
// which is assumed to be a training or eval loop.
STEP_MARK_AT_SECOND_LEVEL_WHILE_LOOP = 3;
// No step marker generated.
STEP_MARK_NONE = 2;
}
// Option to emit a target-specific marker to indicate the start of a training
// step. The location of the marker (if any) is determined by the option
// value.
StepMarkerLocation xla_step_marker_location = 108;
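//
// Illustrative use via XLA_FLAGS (the accepted spelling of the enum value
// is an assumption here):
//
//   XLA_FLAGS="--xla_step_marker_location=STEP_MARK_AT_ENTRY" python my_program.py
//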
//
// BEGIN flags controlling dumping HLO modules for debugging.
//
// When dumping is enabled, HLO modules are dumped at the very beginning and
// end of compilation, and optionally also during the pass pipeline.
//
// In general, if you set one of these flags, we will try to infer reasonable
// defaults for the others. For example:
//
// * Setting --xla_dump_to=/tmp/foo without specifying a format
// with --xla_dump_hlo_as_* will turn on --xla_dump_hlo_as_text.
//
// * Setting --xla_dump_hlo_as_text without specifying --xla_dump_to will
// dump to stdout.
//
// Directory to dump into.
string xla_dump_to = 109;
// If specified, will only dump modules which match this regexp.
string xla_dump_hlo_module_re = 110;
// If this flag is specified, will also dump HLO before and after passes that
// match this regular expression. Set to .* to dump before/after all passes.
string xla_dump_hlo_pass_re = 111;
// Specifies the format that HLO is dumped in. Multiple of these may be
// specified.
bool xla_dump_hlo_as_text = 112;
bool xla_dump_hlo_as_proto = 113;
bool xla_dump_hlo_as_dot = 114;
bool xla_dump_hlo_as_url = 115;
// Dump HLO graphs as HTML (DOT -> SVG inlined in HTML).
bool xla_dump_hlo_as_html = 116;
// Dump the visualization of the fusion progress.
bool xla_dump_fusion_visualization = 149;
// If true, every time an HLO module is run, we will dump an HloSnapshot
// (essentially, a serialized module plus its inputs) to the --xla_dump_to
// directory.
bool xla_dump_hlo_snapshots = 118;
// Include a timestamp in the dumped filenames.
bool xla_dump_include_timestamp = 131;
// Max number of HLO module dumps in a directory. Set to < 0 for unbounded.
int32 xla_dump_max_hlo_modules = 132;
// Dump HloModuleMetadata as a text proto for each HLO module.
bool xla_dump_module_metadata = 144;
// GZip-compress protos dumped via --xla_dump_hlo_as_proto.
bool xla_dump_compress_protos = 151;
// Dump HLO in long text format. Ignored unless xla_dump_hlo_as_text is true.
bool xla_dump_hlo_as_long_text = 164;
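//
// A sketch combining the dump flags above (the path, regexp, and script
// name are hypothetical):
//
//   XLA_FLAGS="--xla_dump_to=/tmp/xla_dump \
//              --xla_dump_hlo_as_text=true \
//              --xla_dump_hlo_pass_re=fusion" python my_program.py
//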
//
// END flags controlling dumping HLO modules.
//
// Overrides for XLA GPU's convolution layout heuristic.
bool xla_gpu_force_conv_nchw = 125;
bool xla_gpu_force_conv_nhwc = 146;
// Paths to files with PTX code.
repeated string xla_gpu_ptx_file = 127;
// Whether to dump LLVM IR when compiling to PTX.
bool xla_gpu_dump_llvmir = 155;
// Denylist for cuDNN convolutions.
string xla_gpu_algorithm_denylist_path = 128;
reserved 130; // Was xla_gpu_deterministic_reductions
// Debug options that trigger execution errors when NaN or Inf are detected.
bool xla_tpu_detect_nan = 135;
bool xla_tpu_detect_inf = 136;
// True if TraceMe annotations are enabled for XLA:CPU.
bool xla_cpu_enable_xprof_traceme = 137;
// It is usually preferable not to fall back to the driver; it can consume
// more memory or have bugs.
bool xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found = 138;
// Extra parameters to pass to the GPU assembler.
string xla_gpu_asm_extra_flags = 141;
// Per-heap size constraint. New heaps will be created if per-heap max size is
// reached.
int32 xla_multiheap_size_constraint_per_heap = 142;
// Enable detailed logging into VLOG and XLA dumping. If this is disabled, no
// compilation summary will be printed at the end of the computation and no
// HLO modules will be dumped.
bool xla_detailed_logging_and_dumping = 143;
// Overrides the normal multi-threaded compilation setting to use this many
// threads. Setting to 0 (the default value) means no enforcement.
int32 xla_gpu_force_compilation_parallelism = 147;
// Guarantees run-to-run determinism. At present, the HLO ops Scatter and
// SelectAndScatter do not have deterministic XLA:GPU implementations.
// Compilation errors out if these ops are encountered.
bool xla_gpu_deterministic_ops = 148;
// Paths to files with LLVM code.
repeated string xla_gpu_llvm_ir_file = 150;
// Convert synchronous all-reduce ops into asynchronous ones.
bool xla_gpu_enable_async_all_reduce = 152;
// Size threshold (in bytes) for the GPU all-reduce combiner.
int64 xla_gpu_all_reduce_combine_threshold_bytes = 157;
// Combine GPU all-reduces into a single operation over a contiguous buffer.
bool xla_gpu_all_reduce_contiguous = 158;
// Number of devices per host for the first stage of the BlueConnect
// decomposition pass. The pass will attempt to decompose all-reduce ops into
// a ReduceScatter-AllReduce-AllGather sequence, with the initial
// ReduceScatter being performed over all of the devices on the same host.
// Set to < 1 to disable all-reduce decomposition.
int32 xla_gpu_all_reduce_blueconnect_num_devices_per_host = 159;
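//
// Illustrative BlueConnect setting: with 2 hosts of 8 devices each (a
// hypothetical topology), the value 8 yields a per-host ReduceScatter, a
// cross-host AllReduce, and a final AllGather:
//
//   XLA_FLAGS="--xla_gpu_all_reduce_blueconnect_num_devices_per_host=8"
//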
// Whether to use the cuDNN frontend API for convolutions when possible.
bool xla_gpu_enable_cudnn_frontend = 160;
// Disable dumping metadata in HLO dumps.
bool xla_dump_disable_metadata = 153;
// If this flag is specified, will only dump HLO before and after passes in
// the pass pipeline that matches this regular expression. Default empty value
// enables dumping in all pipelines.
string xla_dump_hlo_pipeline_re = 154;
// If true, abort immediately when conv algorithm picker fails, rather than
// logging a warning and proceeding with fallback.
bool xla_gpu_strict_conv_algorithm_picker = 156;
// If true, enable XLIR to compile gpu programs to TFRT BEF.
bool xla_gpu_bef_executable = 161;
// If true, enable XLIR to compile thunks to TFRT BEF.
// This flag has no effect when xla_gpu_bef_executable is true.
bool xla_gpu_bef_thunk = 162;
// If true, enable XLIR to compile gpu programs to JitRt.
bool xla_gpu_jitrt_executable = 169;
// Timeout in seconds before terminating jobs that are stuck in an NCCL
// Rendezvous. A negative value disables the timeout; stuck jobs will not be
// terminated.
int64 xla_gpu_nccl_termination_timeout_seconds = 163;
// Enables shared constants for XLA/GPU. This allows large constants to be
// shared among multiple GPU executables.
bool xla_gpu_enable_shared_constants = 165;
// Whether to use cuBLASLt for GEMMs on GPUs.
bool xla_gpu_enable_cublaslt = 166;
// Size threshold (in megabytes) for the GPU redzone scratch allocator.
int64 xla_gpu_redzone_scratch_max_megabytes = 167;
// Allows all floating-point conversions to be simplified, including those
// that affect the numerics. The `BFloat16Normalization` pass inserts many
// `f32 -> bf16 -> f32` conversion pairs. These are not removed by the
// `AlgebraicSimplifier`, as that will only simplify conversions that are
// no-ops, e.g. `bf16 -> f32 -> bf16`. Removing these improves accuracy.
bool xla_gpu_simplify_all_fp_conversions = 168;
// Next id: 170
// Extra options to pass to the compilation backend (e.g. LLVM); specific
// interpretation of these values is left to the backend.
map<string, string> xla_backend_extra_options = 500;
reserved 5, 117, 133,
139; // were xla_hlo_dump_as_graphdef, xla_dump_to_url,
// xla_gpu_use_horizontal_fusion, and
// xla_gpu_unsafe_fallback_to_driver_on_ptxas_error
}
// These settings control how XLA compiles and/or runs code. Not all settings
// will have an effect on every platform.
//
// When adding new fields, keep in mind that boolean fields default to false.
message ExecutionOptions {
// This optional field's layout is used as a hint when storing the output of
// this computation. Subsequent transfers of this output array to the client
// may be faster when using this layout.
//
// We use a Shape here to accommodate computations that return a tuple.
ShapeProto shape_with_output_layout = 2;
// Used to seed random-number generators used in this computation. If this is
// 0, we generate a seed ourselves.
//
// TODO(b/32083678): Changing the seed unnecessarily forces a recompilation.
uint64 seed = 3;
DebugOptions debug_options = 4;
// This optional field specifies a particular set of devices to run the
// computation on. The computation will be partitioned across these devices.
// If not provided, the default device will be chosen.
repeated DeviceHandle device_handles = 5;
// Number of replicas of the computation to run. If zero, uses the default
// number of replicas for the XLA service.
int32 num_replicas = 6;
// This optional field specifies the device assignment if known at compile
// time.
DeviceAssignmentProto device_assignment = 7;
// Alias input and output buffers for parameters that are passed through XLA
// modules without being changed.
bool alias_passthrough_params = 8;
// Number of partitions of the computation to run (model parallelism).
// If zero, uses the default number of partitions for the XLA service.
int32 num_partitions = 9;
// Used to identify a set of programs that should be launched together.
int32 launch_id = 10;
// Indicates whether to use SPMD (true) or MPMD (false) partitioning when
// num_partitions > 1 and XLA is requested to partition the input program.
bool use_spmd_partitioning = 11;
// Whether to automatically generate XLA shardings for the SPMD partitioner.
bool use_auto_spmd_partitioning = 15;
// Device mesh shape used to create the sharding search space when
// use_auto_spmd_partitioning=true.
repeated int64 auto_spmd_partitioning_mesh_shape = 16;
// Device mesh ids compatible with the above mesh_shape used when
// use_auto_spmd_partitioning=true.
repeated int64 auto_spmd_partitioning_mesh_ids = 17;
// If set, deduplicate hlo into function calls to reduce binary size. Only
// works on TPU.
bool deduplicate_hlo = 12;
reserved 13; // Was broadcast_replicated_parameters_via_collectives
// Allows sharding propagation to propagate to the outputs. This changes the
// output shape of the computation (which is undesirable), but it can be used
// to run a partial compilation and determine what the output sharding of a
// computation would be if XLA were allowed to propagate the sharding.
// Higher-level frameworks can use this to query the intermediate sharding of
// operations when multiple computations are chained and merged together.
bool allow_spmd_sharding_propagation_to_output = 14;
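//
// A minimal ExecutionOptions in proto text format (all values hypothetical),
// sketching an SPMD run over eight partitions:
//
//   num_replicas: 1
//   num_partitions: 8
//   use_spmd_partitioning: true
//   seed: 42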
}
message GetDeviceHandlesRequest {
int64 device_count = 1;
}
message GetDeviceHandlesResponse {
repeated DeviceHandle device_handles = 1;
}
message TransferToClientRequest {
GlobalDataHandle data = 1;
// This optional field directs the service to return the literal in this
// layout. A shape is used to hold the layout to accommodate tuples.
ShapeProto shape_with_layout = 2;
}
message TransferToClientResponse {
LiteralProto literal = 1;
}
message TransferToServerRequest {
LiteralProto literal = 1;
DeviceHandle device_handle = 2;
}
message TransferToServerResponse {
GlobalDataHandle data = 1;
}
message TransferToInfeedRequest {
LiteralProto literal = 1;
int64 replica_id = 2;
DeviceHandle device_handle = 3;
}
message TransferToInfeedResponse {}
message TransferFromOutfeedRequest {
// This optional field directs the service to return the literal in this
// layout. A shape is used to hold the layout to accommodate tuples.
ShapeProto shape_with_layout = 1;
int64 replica_id = 2;
DeviceHandle device_handle = 3;
}
message TransferFromOutfeedResponse {
LiteralProto literal = 1;
}
message ResetDeviceRequest {
DeviceHandle device_handle = 1;
}
message ResetDeviceResponse {}
message ComputationGraphStatsRequest {
HloModuleProto computation = 1;
DebugOptions debug_options = 2;
}
message ComputationStatsResponse {
ComputationStats stats = 1;
}
message CreateChannelHandleRequest {
ChannelHandle.ChannelType channel_type = 1;
}
message CreateChannelHandleResponse {
ChannelHandle channel = 1;
}
message UnregisterRequest {
repeated GlobalDataHandle data = 1;
}
message UnregisterResponse {}
message CompileRequest {
// The graph to be compiled.
HloModuleProto computation = 1;
// Options that affect how XLA compiles code to service this request.
ExecutionOptions execution_options = 2;
// The layouts of the input arguments. If not set, the default layout will be
used. Although the real arguments are not needed for compilation, the
layouts of the arguments can affect the compilation.
repeated ShapeProto input_shape_with_layout = 3;
}
message CompileResponse {
// The handle to the executable.
ExecutionHandle handle = 1;
}
message ExecuteRequest {
ExecutionHandle handle = 1;
// The shape and layout of the arguments must be the same as those of the
// executable's parameters.
repeated GlobalDataHandle arguments = 2;
}
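// Sketch of the intended compile-then-execute flow, in proto text format
// (handle values are hypothetical; GlobalDataHandle and ExecutionHandle each
// wrap a single int64 `handle` field):
//
//   # CompileRequest -> CompileResponse { handle { handle: 7 } }
//   # ExecuteRequest reusing the returned handle:
//   handle { handle: 7 }
//   arguments { handle: 11 }
//   arguments { handle: 12 }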
// TODO(b/118493728): Remove this and ExecuteGraphParallelRequest and replace
// the uses with calls to Compile and Execute.
message ExecuteGraphRequest {
HloModuleProto computation = 1;
repeated GlobalDataHandle arguments = 2;
// Options that affect how XLA compiles and runs code to service this request.
ExecutionOptions execution_options = 3;
}
message ExecuteGraphParallelRequest {
repeated ExecuteGraphRequest requests = 1;
}
message ExecuteResponse {
GlobalDataHandle output = 1;
ExecutionProfile profile = 2;
}
message ExecuteParallelResponse {
repeated ExecuteResponse responses = 1;
}
message WaitForExecutionRequest {
ExecutionHandle execution = 1;
}
message WaitForExecutionResponse {
GlobalDataHandle output = 1;
ExecutionProfile profile = 2;
}
message ComputeConstantGraphRequest {
HloModuleProto computation = 1;
LayoutProto output_layout = 2;
}
message ComputeConstantResponse {
// A LiteralProto is returned directly for this request.
LiteralProto literal = 1;
}
message DeconstructTupleRequest {
GlobalDataHandle tuple_handle = 2;
}
message DeconstructTupleResponse {
repeated GlobalDataHandle element_handles = 1;
}
message LoadDataRequest {
// Describes the path of the ColumnIO tablet to load.
string columnio_tablet_path = 1;
// Describes the field to load within the ColumnIO tablet.
string columnio_field = 2;
// Individual element shape, excluding rows.
ShapeProto element_shape = 3;
// Warning: ColumnIO does not support random-access, so use offset with
// caution in performance-critical scenarios.
int64 offset = 4;
// Maximum number of elements (with shape element_shape) to load.
int64 limit = 5;
// If more than one item is requested (via limit > 1), then this request
// attribute zips together the produced vectors.
bool zip = 6;
}
message LoadDataResponse {
GlobalDataHandle data = 1;
ShapeProto data_shape = 2;
int64 available_rows = 3;
int64 rows_loaded = 4;
int64 nanoseconds = 5;
}
message GetShapeRequest {
GlobalDataHandle data = 1;
}
message GetShapeResponse {
ShapeProto shape = 1;
}
message UnpackRequest {
GlobalDataHandle data = 1;
}
message UnpackResponse {
repeated GlobalDataHandle tied_data = 1;
}