| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| syntax = "proto3"; |
| |
| package xla; |
| |
| option cc_enable_arenas = true; |
| |
| // Primitive types are the individual values that can be held in rectangular |
| // multidimensional arrays. A description of the rectangular multidimensional |
| // array dimensions / primitive type is given by Shape, below. |
| enum PrimitiveType { |
| // Invalid primitive type to serve as default. |
| PRIMITIVE_TYPE_INVALID = 0; |
| |
| // Predicates are two-state booleans. |
| PRED = 1; |
| |
| // Signed integral values of fixed width. |
| S8 = 2; |
| S16 = 3; |
| S32 = 4; |
| S64 = 5; |
| |
| // Unsigned integral values of fixed width. |
| U8 = 6; |
| U16 = 7; |
| U32 = 8; |
| U64 = 9; |
| |
| // Floating-point values of fixed width. |
| // |
| // Note: if f16s are not natively supported on the device, they will be |
| // converted to f16 from f32 at arbitrary points in the computation. |
| F16 = 10; |
| F32 = 11; |
| |
| // Truncated 16-bit floating-point format. This is similar to IEEE's 16-bit |
| // floating-point format, but uses 1 bit for the sign, 8 bits for the |
| // exponent, and 7 bits for the mantissa. |
| BF16 = 16; |
| |
| F64 = 12; |
| |
| // Complex values of fixed width. |
| C64 = 15; // Paired F32 (real, imag), as in std::complex<float>. |
| C128 = 18; // Paired F64 (real, imag), as in std::complex<double>. |
| |
| // A tuple is a polymorphic sequence; e.g. a shape that holds different |
| // sub-shapes. They are used for things like returning multiple values from a |
| // computation; e.g. a computation that returns weights and biases may have a |
| // signature that results in a tuple like (f32[784x2000], f32[2000]). |
| // |
| // If a shape proto has the tuple element type, it may not have any entries |
| // in the dimensions field. |
| TUPLE = 13; |
| |
| // An opaque type used for passing context-specific data to a custom |
| // operation. Shapes of this primitive type will have empty dimensions and |
| // tuple_shapes fields. |
| // |
| // (OPAQUE would be a better name for this identifier, but that conflicts with |
| // a macro defined in windows.h.) |
| OPAQUE_TYPE = 14; |
| |
| // A token type threaded between side-effecting operations. Shapes of this |
| // primitive type will have empty dimensions and tuple_shapes fields. |
| TOKEN = 17; |
| |
| // Next = 19 |
| } |
| |
| // Describes the padding configuration for a Pad operation. The padding amount |
| // on both edges, as well as between the elements, is specified for each |
| // dimension. |
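| // |
| // For example (an illustrative worked case, not taken from this file): |
| // applying edge_padding_low = 1, edge_padding_high = 2, and |
| // interior_padding = 1 to the one-dimensional array [1, 2, 3] with padding |
| // value 0 yields [0, 1, 0, 2, 0, 3, 0, 0]. |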
| message PaddingConfig { |
| // Describes the padding configuration for a dimension. |
| message PaddingConfigDimension { |
| // Padding amount on the low end (next to index 0). May be negative. |
| int64 edge_padding_low = 1; |
| |
| // Padding amount on the high end (next to the highest index). May be |
| // negative. |
| int64 edge_padding_high = 2; |
| |
| // Padding amount between the elements. May not be negative. |
| int64 interior_padding = 3; |
| } |
| |
| // The padding configuration for all dimensions. |
| repeated PaddingConfigDimension dimensions = 1; |
| } |
| |
| // A format specifies the method used by a layout to store an array in memory. |
| enum Format { |
| // TODO(b/120869032): Rename this to FORMAT_NONE or something else which |
| // better corresponds to its meaning. |
| INVALID_FORMAT = 0; |
| // The default layout, with exactly one storage location per element. |
| DENSE = 1; |
| // A sparsely encoded layout, providing only the index/value pairs of non-zero |
| // elements. |
| SPARSE = 2; |
| } |
| |
| // Describes a tile used in tiling-based layout. Refer to |
| // g3doc/third_party/tensorflow/compiler/xla/g3doc/layout_with_tiling.md for |
| // details about tiling-based layout. |
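| // |
| // For example (illustrative): a tile with dimensions [2, 128] partitions a |
| // two-dimensional array into 2x128 blocks, and the elements within each |
| // block are laid out contiguously in memory. |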
| message TileProto { |
| // Number of elements in each dimension of the tile. It's ordered from the |
| // most major dimension of the tile to the most minor dimension of the tile. |
| // The dimensions correspond to a suffix of the dimensions of the shape being |
| // tiled. |
| repeated int64 dimensions = 1; |
| } |
| |
| // A layout describes how the array is placed in (1D) memory space. This |
| // includes the minor-to-major ordering of dimensions within a shape. |
| // |
| // Clients must specify the layouts of input Literals to the |
| // computation. Layouts specified in interior operations which take Shapes (for |
| // example, Convert) are ignored. |
| // |
| // See the XLA documentation for more information on shapes and layouts. |
| // |
| // LINT.IfChange |
| message LayoutProto { |
| // The method used to store the data in memory. The format determines which of |
| // the other fields are used by the layout. |
| Format format = 4; |
| |
| // Sequence of dimension numbers, from minor (fastest varying index) to major |
| // (slowest varying index). This field is required. |
| repeated int64 minor_to_major = 1; |
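| // |
| // For example (illustrative): for a two-dimensional array, minor_to_major = |
| // [1, 0] is a row-major layout (dimension 1 varies fastest in memory) and |
| // [0, 1] is a column-major layout (dimension 0 varies fastest). |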
| |
| reserved 2; |
| reserved "padded_dimensions"; |
| |
| reserved 3; |
| reserved "padding_value"; |
| |
| // The maximum number of elements that can be stored for SPARSE formats. This |
| // can be used to determine the maximum size in bytes of arrays stored in |
| // memory. This field must be unset unless the format is SPARSE. |
| int64 max_sparse_elements = 5; |
| |
| // A sequence of tiles, starting from the tile that's applied first to the |
| // Shape. |
| // |
| // TODO(b/119839262): implement tiling in each backend or add Unimplemented |
| // error. |
| repeated TileProto tiles = 6; |
| |
| // Bit size of each element. If the size is bigger than what the element |
| // type requires, the value is stored in the least significant |
| // bits and the additional most significant bits are filled with 0's. |
| // |
| // TODO(b/119839262): implement in each backend or add Unimplemented error. |
| int64 element_size_in_bits = 7; |
| |
| // Memory space where this array resides. The integer field is interpreted in |
| // a backend-specific manner. |
| int64 memory_space = 8; |
| |
| // Important: if any field is added, be sure to modify ShapeUtil::Equal() and |
| // LayoutUtil::Hash appropriately to account for the new field. |
| } |
| // LINT.ThenChange( \ |
| // https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc, \ |
| // https://www.tensorflow.org/code/tensorflow/compiler/xla/layout_util.cc) |
| |
| // A shape describes the number of dimensions in the array, the size of each |
| // dimension, and the primitive component type. |
| // |
| // Tuples are a special case in that they have rank zero and have tuple_shapes |
| // defined. |
| // |
| // See the XLA documentation for more information on shapes and layouts. |
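| // |
| // For example (an illustrative textproto sketch): the array shape |
| // f32[10, 20] with a row-major dense layout could be written as |
| // |
| //   element_type: F32 |
| //   dimensions: [10, 20] |
| //   layout { format: DENSE minor_to_major: [1, 0] } |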
| // |
| // LINT.IfChange |
| message ShapeProto { |
| reserved 1; |
| reserved "rank"; |
| |
| // The element type for this shape. |
| PrimitiveType element_type = 2; |
| |
| // The size (number of elements) for each dimension, or an upper bound on the |
| // size if the dimension is dynamic. In XLA, dimensions are numbered from 0 |
| // to N-1 for an N-dimensional array. The first element of 'dimensions' is the |
| // size of dimension 0, the second element is the size of dimension 1, and so |
| // forth. An empty list indicates a scalar. |
| // |
| // If the respective element in 'is_dimension_dynamic' is true then the value |
| // in this field represents an upper bound on the size of the dimension. |
| repeated int64 dimensions = 3; |
| |
| // For tuples only, the shapes of the constituent elements of the tuple. |
| repeated ShapeProto tuple_shapes = 4; |
| |
| // The layout used to back this shape. |
| LayoutProto layout = 5; |
| |
| // For arrays, this indicates whether or not each dimension is |
| // dynamically-sized. The number of elements in this repeated field should be |
| // zero (indicating that no dimensions are dynamic) or equal to the number of |
| // elements in the 'dimensions' field. |
| repeated bool is_dynamic_dimension = 6; |
| |
| // Important: if any field is added, be sure to modify ShapeUtil::Equal(), |
| // ShapeUtil::Compatible() and ShapeUtil::Hash() appropriately to account for |
| // the new field. |
| } |
| // LINT.ThenChange( \ |
| // https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc) |
| |
| // Shape of the parameters and output of a computation (like a traditional |
| // function signature). |
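| // |
| // For example (illustrative): a computation with signature |
| // (f32[784], f32[10]) -> f32[10] has two entries in 'parameters' and a |
| // 'result' of shape f32[10]. |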
| message ProgramShapeProto { |
| repeated ShapeProto parameters = 1; |
| ShapeProto result = 2; |
| repeated string parameter_names = 3; |
| } |
| |
| // Statistics of a computation. |
| message ComputationStats { |
| // The number of floating point operations in the computation. |
| double flop_count = 1; |
| |
| // The number of transcendental operations (e.g., exp) in the computation. |
| double transcendental_count = 2; |
| } |
| |
| // Symbolization metadata for HLO Instructions. |
| // |
| // This metadata is used for debugging XLA code generation, as well as |
| // performance profiling of XLA-generated executables. |
| message OpMetadata { |
| // The framework op name that generated this XLA op. |
| // |
| // Frameworks that build on top of XLA should mirror the names of their ops |
| // back to users by specifying the op_type. In this way, even if the |
| // framework's "ops" are implemented as multiple XLA HLO Ops, they can be |
| // grouped appropriately. (e.g. if a SoftMax layer is emitted into XLA as |
| // multiple ops, then each op should have its op_type set to "SoftMax".) |
| string op_type = 1; |
| // The user-specified name of the op. |
| // |
| // This name is often unique within a computation. Note: some frameworks |
| // add auto-generated names if the user does not provide one. |
| string op_name = 2; |
| // Indicates the file and line that this op is associated with in a user's |
| // program. |
| // |
| // e.g. it could be the file and line of user code that generated the op. |
| string source_file = 3; |
| int32 source_line = 4; |
| } |
| |
| // Profile data from the execution of a computation. |
| message ExecutionProfile { |
| // Whether the executable was read from the compilation cache. |
| bool compilation_cache_hit = 1; |
| |
| // The time in milliseconds spent to compile the computation. This is only |
| // set if the executable was not read from the compilation cache |
| // (compilation_cache_hit == false). |
| int64 compile_time_ms = 2; |
| |
| // The number of cycles spent for the computation. This does not include the |
| // time taken for the data transfers between the host and the device. This is |
| // a target-dependent field and only used for debugging purposes. |
| int64 compute_cycle_count = 3; |
| |
| // The time in nanoseconds spent for the computation, without data transfer. |
| int64 compute_time_ns = 4; |
| |
| // The time in nanoseconds spent for the entire computation, including the |
| // result data transfer time. The current implementation does not spend any |
| // cycles for the input data transfer since the memory is initialized with |
| // the proper values before the execution. |
| int64 compute_and_transfer_time_ns = 5; |
| |
| // The size of the binary code in the executable. |
| int64 executable_size_in_bytes = 6; |
| |
| // Whether this profile was drawn from a cache of profiles instead of from |
| // execution on the hardware. |
| bool profile_cache_hit = 7; |
| } |
| |
| // Handle given to a user that represents an execution that the user launched |
| // asynchronously on the device. |
| message ExecutionHandle { |
| int64 handle = 1; |
| } |
| |
| // Handle given to a user that represents a globally accessible allocation. |
| // Contrast this against a ComputationDataHandle, which is not globally |
| // accessible, since it only exists within a specific computation. |
| message GlobalDataHandle { |
| int64 handle = 1; |
| } |
| |
| // Handle given to a user that represents a replicated virtual device. Each |
| // replicated device represents N physical devices for execution where N is the |
| // number of replicas. |
| message DeviceHandle { |
| int64 handle = 1; |
| |
| // The number of model-parallel virtual devices that communicate via XLA |
| // Send/Recv instructions. |
| int64 device_count = 2; |
| } |
| |
| // Handle given to a user to represent a channel between two computations |
| // via a Send and Recv instruction pair. Channels are unbuffered, so Send |
| // instructions will be blocked until the data is transferred. |
| message ChannelHandle { |
| int64 handle = 1; |
| enum ChannelType { |
| // Invalid channel type to serve as default. |
| CHANNEL_TYPE_INVALID = 0; |
| |
| // A channel for sending data between devices. |
| DEVICE_TO_DEVICE = 1; |
| |
| // A channel for sending data from the device to the host. Can only be used |
| // with a Send operation. |
| DEVICE_TO_HOST = 2; |
| |
| // A channel for sending data from the host to the device. Can only be used |
| // with a Recv operation. |
| HOST_TO_DEVICE = 3; |
| } |
| ChannelType type = 2; |
| } |
| |
| // DeviceAssignmentProto is a serialized form of DeviceAssignment class, which |
| // represents the device ids assigned to a set of replicated computations. |
| // See xla::DeviceAssignment class comment for more details. |
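| // |
| // For example (illustrative): with computation_count = 2 and |
| // replica_count = 3, 'computation_devices' has two entries, each listing |
| // the three device ids that run the replicas of that computation. |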
| message DeviceAssignmentProto { |
| int32 replica_count = 1; |
| int32 computation_count = 2; |
| |
| // Each logical computation runs on replica_count physical devices. |
| // ComputationDevice represents the device ids assigned to the replicas. |
| message ComputationDevice { |
| repeated int32 replica_device_ids = 1; |
| } |
| repeated ComputationDevice computation_devices = 3; |
| } |
| |
| // Literals are used when the server and client need to exchange materialized |
| // data / results. Literals are also used to describe constants used in |
| // computations. |
| // |
| // Transfers to/from the client are encoded in literal form, and the structure |
| // of the repeated fields is implied by the shape. |
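| // |
| // For example (an illustrative textproto sketch): the literal |
| // f32[2] {1.0, 2.0} could be encoded as |
| // |
| //   shape { |
| //     element_type: F32 |
| //     dimensions: [2] |
| //     layout { format: DENSE minor_to_major: [0] } |
| //   } |
| //   f32s: [1.0, 2.0] |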
| message LiteralProto { |
| ShapeProto shape = 1; |
| repeated bool preds = 2; |
| bytes s8s = 15; |
| bytes u8s = 3; |
| repeated int32 s32s = 4; |
| repeated int64 s64s = 5; |
| repeated uint32 u32s = 6; |
| repeated uint64 u64s = 7; |
| repeated float f32s = 8; |
| repeated double f64s = 9; |
| repeated float c64s = 12; // Stored as interleaved real, imag floats. |
| repeated double c128s = 18; // Stored as interleaved real, imag doubles. |
| repeated LiteralProto tuple_literals = 10; |
| // The F16s, BF16s, U16s and S16s are encoded in little-endian byte order. |
| bytes f16s = 11; |
| bytes bf16s = 13; |
| bytes u16s = 16; |
| bytes s16s = 17; |
| repeated int64 sparse_indices = 14; |
| // Next = 19 |
| } |
| |
| message WindowDimension { |
| // The size of the window in this dimension. For a rectangle, this would be |
| // the width or height. |
| int64 size = 1; |
| |
| // The stride at which the window moves across the base area in this |
| // dimension. In other words, this is the spacing between different |
| // positions of the window in this dimension. |
| int64 stride = 2; |
| |
| // If positive, the amount of padding to add to the base area at the low end |
| // of this dimension; if negative, its absolute value is the number of |
| // elements removed from the low end of this dimension. For example, in the |
| // horizontal dimension of a rectangle, this would be the number of padding |
| // values to pad on the left, given that indices increase when going right. |
| // The actual padding value depends upon the context. Convolution pads with |
| // zeros. ReduceWindow and SelectAndScatter pad with the reduce function's |
| // init value. |
| int64 padding_low = 3; |
| |
| // As padding_low, but on the high end of this dimension. For example, in the |
| // horizontal dimension of a rectangle, this would be the number of values to |
| // pad on the right, given that indices increase when going right. |
| int64 padding_high = 4; |
| |
| // Dilation factor of the sliding window in this dimension. A dilation factor |
| // of 1 means no dilation. window_dilation - 1 no-op entries ("holes") are |
| // implicitly placed between each kernel element. This value may not be less |
| // than 1. See documentation for convolution. |
| int64 window_dilation = 5; |
| |
| // Dilation factor of the base area in this dimension. A dilation factor of 1 |
| // means no dilation. base_dilation - 1 no-op entries ("holes") are implicitly |
| // placed between each base area element. This value may not be less than 1. |
| // See documentation for convolution. |
| int64 base_dilation = 6; |
| |
| // Window reversal means that this dimension was logically reversed before the |
| // operation. |
| bool window_reversal = 7; |
| } |
| |
| // Describes the windowing in an operation such as convolution. |
| // |
| // The window is moved across a base area and for each position of the |
| // window a computation is performed. The field below describes the |
| // window and the movement of the window across a base area. |
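| // |
| // For example (illustrative): a 2x2 pooling window moving by 2 in each of |
| // two spatial dimensions corresponds to two WindowDimension entries, each |
| // with size = 2, stride = 2, padding_low = padding_high = 0, and |
| // window_dilation = base_dilation = 1. |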
| message Window { |
| repeated WindowDimension dimensions = 1; |
| } |
| |
| // Describes the dimension numbers for a gather operation. |
| // |
| // See https://www.tensorflow.org/performance/xla/operation_semantics#gather for |
| // more details. |
| message GatherDimensionNumbers { |
| // "Window indices" is a term for a set of indices that index into the |
| // interior of a dynamic-slice from the input tensor, the starting indices |
| // for which were computed from output_gather_dims (see the operation |
| // semantics for how this is defined) and the start_indices tensor. |
| // |
| // The window indices for a specific output index Out are computed as: |
| // |
| // i = 0 |
| // for (k : [0, input_tensor_shape.rank)) |
| // window_indices[k] = |
| // if k in collapsed_slice_dims |
| // then 0 |
| // else Out[offset_dims[i++]] |
| repeated int64 offset_dims = 1; |
| repeated int64 collapsed_slice_dims = 2; |
| |
| // This is interpreted as a map from i to start_index_map[i]. It |
| // transforms the gather index looked up from the start_indices tensor into |
| // the starting index in the input space. |
| repeated int64 start_index_map = 3; |
| |
| // The dimension in the start_indices input that contains the starting |
| // indices. |
| int64 index_vector_dim = 4; |
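| |
| // For example (an illustrative configuration): gathering rows of an |
| // operand f32[N, C] with indices s32[B] and slice sizes [1, C] could use |
| // offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], |
| // and index_vector_dim = 1, producing an output of shape f32[B, C]. |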
| } |
| |
| // Describes the dimension numbers for a scatter operation. |
| // |
| // All the fields are similar to the corresponding fields in |
| // GatherDimensionNumbers. Differences are noted below. |
| message ScatterDimensionNumbers { |
| // The set of dimensions in the updates shape that are window dimensions. |
| repeated int64 update_window_dims = 1; |
| // The set of window dimensions that must be inserted into the updates shape. |
| repeated int64 inserted_window_dims = 2; |
| |
| repeated int64 scatter_dims_to_operand_dims = 3; |
| int64 index_vector_dim = 4; |
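| |
| // For example (an illustrative configuration, mirroring the gather example |
| // above): scattering updates f32[B, C] into an operand f32[N, C] using |
| // indices s32[B] could use update_window_dims = [1], |
| // inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], and |
| // index_vector_dim = 1. |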
| } |
| |
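| // Describes the dimension numbers used by a convolution. |
| // |
| // For example (an illustrative mapping): for an NHWC input and an HWIO |
| // kernel, input_batch_dimension = 0, input_feature_dimension = 3, |
| // input_spatial_dimensions = [1, 2], kernel_spatial_dimensions = [0, 1], |
| // kernel_input_feature_dimension = 2, and |
| // kernel_output_feature_dimension = 3. |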
| message ConvolutionDimensionNumbers { |
| // The number of the dimension that represents batch in the input. |
| int64 input_batch_dimension = 7; |
| |
| // The number of the dimension that represents features in the input. |
| int64 input_feature_dimension = 8; |
| |
| // The dimension numbers for the spatial dimensions that the window |
| // moves through in the input. |
| repeated int64 input_spatial_dimensions = 11; |
| |
| // The number of the dimension that represents input features in the |
| // convolutional kernel (rhs). |
| int64 kernel_input_feature_dimension = 3; |
| |
| // The number of the dimension that represents output features in |
| // the convolutional kernel (rhs). |
| int64 kernel_output_feature_dimension = 4; |
| |
| // The dimension numbers for the spatial dimensions that the window |
| // moves through in the kernel (rhs). window.strides(0) is the |
| // stride in the kernel_spatial_dimensions(0) dimension. |
| repeated int64 kernel_spatial_dimensions = 6; |
| |
| // The number of the dimension that represents batch in the output. |
| int64 output_batch_dimension = 9; |
| |
| // The number of the dimension that represents features in the output. |
| int64 output_feature_dimension = 10; |
| |
| // The dimension numbers for the spatial dimensions that the window |
| // moves through in the output. |
| repeated int64 output_spatial_dimensions = 12; |
| |
| // Next = 13 |
| } |
| |
| enum FftType { |
| FFT = 0; // Forward FFT; complex in, complex out. |
| IFFT = 1; // Inverse FFT; complex in, complex out. |
| RFFT = 2; // Forward real FFT; real in, fft_length / 2 + 1 complex out. |
| IRFFT = 3; // Inverse real FFT; fft_length / 2 + 1 complex in, |
| // fft_length real out. |
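| |
| // For example (illustrative): an RFFT over fft_length = [8] maps 8 real |
| // input elements to 8 / 2 + 1 = 5 complex output elements. |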
| } |
| |
| message DotDimensionNumbers { |
| // The dimension numbers that represent the 'lhs' contracting dimensions. |
| repeated int64 lhs_contracting_dimensions = 1; |
| // The dimension numbers that represent the 'rhs' contracting dimensions. |
| repeated int64 rhs_contracting_dimensions = 2; |
| // The dimension numbers that represent the 'lhs' batch dimensions. |
| repeated int64 lhs_batch_dimensions = 3; |
| // The dimension numbers that represent the 'rhs' batch dimensions. |
| repeated int64 rhs_batch_dimensions = 4; |
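| |
| // For example (illustrative): a plain matrix multiplication of |
| // lhs f32[m, k] with rhs f32[k, n] contracts lhs dimension 1 against rhs |
| // dimension 0: lhs_contracting_dimensions = [1], |
| // rhs_contracting_dimensions = [0], and no batch dimensions. |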
| } |
| |
| enum RandomDistribution { |
| RNG_INVALID = 0; |
| |
| // Creates a uniform-distribution-generated random number on the semi-open |
| // interval [parameter[0], parameter[1]). |
| RNG_UNIFORM = 1; |
| |
| // Creates a normal-distribution-generated random number with mean |
| // parameter[0] and standard deviation parameter[1]. |
| RNG_NORMAL = 2; |
| |
| // Next: 4 |
| } |
| |
| message TriangularSolveOptions { |
| // If true, solves ax = b. If false, solves xa = b. |
| bool left_side = 1; |
| |
| // If true, 'a' is lower triangular. If false, 'a' is upper triangular. |
| bool lower = 2; |
| |
| // If true, the diagonal elements of 'a' are assumed to be 1 and not accessed. |
| bool unit_diagonal = 3; |
| |
| // Should we transpose or use the adjoint of 'a'? |
| enum Transpose { |
| TRANSPOSE_INVALID = 0; |
| NO_TRANSPOSE = 1; // Don't transpose 'a'. |
| TRANSPOSE = 2; // Transpose 'a'. |
| ADJOINT = 3; // Complex conjugate and transpose 'a'. |
| }; |
| Transpose transpose_a = 4; |
| } |
| |
| message CholeskyOptions { |
| // If true, uses the lower triangle of `a`. If false, uses the upper triangle |
| // of `a`. |
| bool lower = 1; |
| } |
| |
| message OpSharding { |
| enum Type { |
| // This sharding is replicated across all devices (implies maximal, |
| // all other fields are unused). |
| REPLICATED = 0; |
| // This sharding is maximal - one device runs the entire operation. |
| MAXIMAL = 1; |
| // This sharding is a tuple - only the tuple_shardings field is valid. |
| TUPLE = 2; |
| // None of the above; tile_shape and tile_assignment are both used. |
| OTHER = 3; |
| } |
| Type type = 1; |
| // The shape of the sharded tile. |
| ShapeProto tile_shape = 2; |
| // The shape of the tile assignment tensor - this must be the same rank as |
| // tile_shape and the product of its dimensions must equal |
| // tile_assignment_devices.size(). |
| repeated int64 tile_assignment_dimensions = 3; |
| // Flattened list of device IDs. The order of flattening is the same as used |
| // by IndexUtil::MultiToLinearIndex(tile_assignment_shape). |
| repeated int64 tile_assignment_devices = 4; |
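| |
| // For example (a hedged illustration): splitting a two-dimensional array |
| // across four devices into a 2x2 grid of tiles could use type = OTHER, |
| // tile_assignment_dimensions = [2, 2], and |
| // tile_assignment_devices = [0, 1, 2, 3]. |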
| // If type == TUPLE, the sub-shardings, one per leaf node in the tuple shape, |
| // in pre-order. The tuple shape could be nested; here we store just a |
| // flattened list of all leaves in the tuple shape. Note that the tuple shape |
| // is not stored here; shardings do not store the shapes to which they are |
| // applied, as this is inferred from the instruction to which this sharding |
| // is attached. |
| repeated OpSharding tuple_shardings = 5; |
| } |
| |
| // Describes the replica groups in a cross replica op (e.g., all-reduce and |
| // all-to-all). |
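| // |
| // For example (illustrative): with four replicas, the two groups {0, 1} |
| // and {2, 3} make an all-reduce combine data within each pair of replicas |
| // independently. |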
| message ReplicaGroup { |
| // The ids of the replicas that belong to the same group. The ordering of the |
| // ids matters in some ops (e.g., all-to-all). |
| repeated int64 replica_ids = 1; |
| } |
| |
| // Describes the source target pair in the collective permute op. |
| message SourceTarget { |
| int64 source = 1; |
| int64 target = 2; |
| } |
| |
| // Used to indicate the precision configuration. It has backend specific |
| // meaning. |
| message PrecisionConfig { |
| enum Precision { |
| DEFAULT = 0; |
| HIGH = 1; |
| HIGHEST = 2; |
| |
| // Next: 3 |
| } |
| repeated Precision operand_precision = 1; |
| |
| // Next: 2 |
| } |
| |
| // Describes whether all data-parallelism replicas will receive the same |
| // parameter data for each buffer. |
| message ParameterReplication { |
| // A list of boolean values for the flattened leaf buffers. Each value |
| // indicates whether the corresponding leaf buffer is replicated. |
| // |
| // If this field is empty, it means no buffer is replicated. Otherwise, the |
| // number of elements in this field must match the number of leaf buffers in |
| // the HLO instruction's shape. |
| repeated bool replicated_at_leaf_buffers = 1; |
| } |
| |
| // A backend-config for kWhile loops that stores the loop's trip count, if it is |
| // known. |
| // |
| // This is useful for backends that can implement a `for i in 0..N` loop more |
| // efficiently than a `while` loop. For example, on GPUs, we can implement a |
| // `for i in 0..N` loop by enqueueing the kernels for the loop body N times, |
| // whereas implementing a `while` loop requires a host-device sync on each |
| // iteration. |
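| // |
| // For example (an illustrative textproto): a loop known to execute 100 |
| // times would carry known_trip_count { n: 100 }, while a loop whose trip |
| // count is unknown leaves the field unset. |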
| message WhileLoopBackendConfig { |
| message KnownTripCount { |
| int64 n = 1; |
| } |
| // This indirection lets us distinguish between known-trip-count == 0 and |
| // unknown-trip-count. |
| KnownTripCount known_trip_count = 1; |
| } |