| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
// This proto file defines messages that represent the HLO module. This is a
// full-fidelity serialization of the C++ HLO constructs.
| // |
| // Many of the protos below are simple 1-to-1 serializations of the |
| // corresponding C++ classes, e.g., HloModule, HloComputation, and |
| // HloInstruction. |
| // |
| // FIELD NAMES ARE IMPORTANT |
| // |
| // Unlike most protos, you can't safely change the names of fields, even if you |
| // keep the numeric ids the same. This is because we sometimes serialize these |
| // protos as JSON, which includes the field names in the serialization. |
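//
// For example (an illustrative sketch; the names and ids are hypothetical),
// an HloInstructionProto serialized as JSON embeds the field names
// themselves:
//
//   {"name": "add.1", "opcode": "add", "id": 3}
//
// so renaming `opcode` would break existing JSON payloads even though the
// field number is unchanged.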
| |
| syntax = "proto3"; |
| |
| package xla; |
| |
| import "tensorflow/compiler/xla/xla_data.proto"; |
| |
| option cc_enable_arenas = true; |
| |
| // Serialization of HloInstruction. |
| // Next ID: 75 |
| message HloInstructionProto { |
| reserved 10; |
| reserved "parameter_name"; |
| reserved 12; |
| reserved "fused_instructions_computation"; |
| reserved 4; |
| reserved "operand_names"; |
| reserved 5; |
| reserved "control_predecessor_names"; |
| reserved 6; |
| reserved "called_computation_names"; |
| reserved 44; |
| reserved "replica_group_ids"; |
// Use backend_config instead of custom_call_opaque.
| reserved 53; |
| reserved "custom_call_opaque"; |
// Use backend_config instead of all_reduce_barrier.
| reserved 46; |
| reserved "all_reduce_barrier"; |
| |
| string name = 1; |
| string opcode = 2; |
| xla.ShapeProto shape = 3; |
| |
| xla.OpMetadata metadata = 7; |
| |
| // Literal, only present for kConstant. |
| xla.LiteralProto literal = 8; |
| |
| // Parameter number is only present for kParameter. |
| int64 parameter_number = 9; |
| |
| // Fusion state, only present for kFusion. |
| string fusion_kind = 11; |
| |
| // Index for kGetTupleElement. |
| int64 tuple_index = 13; |
| |
| // Dimensions present for some operations that require reshaping or |
| // broadcasting, including Reshape, Reduce, ReduceWindow, and Reverse. |
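// For example (illustrative), a Reduce of a f32[2,3,4] operand with
// dimensions = {0, 2} reduces away dimensions 0 and 2, yielding a f32[3]
// result.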
| repeated int64 dimensions = 14; |
| |
| // Describes the window in a windowed operation such as convolution. |
| xla.Window window = 15; |
| |
| // Describes the dimension numbers used for a convolution. |
| xla.ConvolutionDimensionNumbers convolution_dimension_numbers = 16; |
| |
// The number of feature groups. Used for a convolution. Must be a divisor of
// the input feature dimension and output feature dimension. Defaults to 1 if
// not specified.
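//
// For example (illustrative values), with an input feature dimension of 12,
// an output feature dimension of 8, and feature_group_count = 4, the
// convolution is split into 4 groups: each group maps 12 / 4 = 3 input
// features to 8 / 4 = 2 output features.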
| int64 feature_group_count = 50; |
| |
// The number of batch groups. Used for a convolution. Defaults to 1 if not
// specified.
int64 batch_group_count = 58;
| |
| // Describes the [begin, end) index range and stride for slices. |
| message SliceDimensions { |
| int64 start = 1; |
| int64 limit = 2; |
| int64 stride = 3; |
| } |
| repeated SliceDimensions slice_dimensions = 17; |
| |
| // The bit sizes for a reduce-precision operation. |
| int32 exponent_bits = 18; |
| int32 mantissa_bits = 19; |
| |
| // Describes the [start, start + size) range size for a dynamic slice |
| // ('start' is specified dynamically in the second operand of the operation). |
| repeated int64 dynamic_slice_sizes = 20; |
| |
| // The padding configuration that describes the edge padding and interior |
| // padding of this pad instruction. Only set for pad instructions. |
| xla.PaddingConfig padding_config = 21; |
| |
| // Outfeed configuration information, only present for kOutfeed. |
| bytes outfeed_config = 22; |
| |
| // The distribution requested for random number generation. |
| // Only present for kRng. |
| xla.RandomDistribution distribution = 23; |
| |
// A small float number added to the variance to avoid divide-by-zero errors.
| // Only present for kBatchNormTraining. |
| float epsilon = 24; |
| |
| // An integer value representing the index of the feature dimension. |
| // Only present for kBatchNormTraining. |
| int64 feature_index = 25; |
| |
| // Represents a unique identifier for each Send/Recv instruction pair or |
| // optionally for collective instructions (AllReduce, CollectivePermute, |
| // AllToAll). Non-positive channel_id is equivalent to no channel id. |
| int64 channel_id = 26; |
| |
| // The string representation of the infeed configuration. |
| bytes infeed_config = 27; |
| |
// Name of an external target (e.g., a global symbol) to call; only present
// for kCustomCall.
| string custom_call_target = 28; |
| |
| // Shape of outfeed request. |
| xla.ShapeProto outfeed_shape = 29; |
| |
// Describes the dimension numbers used for a dot operation.
| xla.DotDimensionNumbers dot_dimension_numbers = 30; |
| |
// FFT type (FFT, IFFT, etc.).
| xla.FftType fft_type = 31; |
| |
| // FFT length. |
| repeated int64 fft_length = 32; |
| |
// Comparison direction, only used for kCompare.
| string comparison_direction = 63; |
| |
| // Gather dimension numbers. |
| xla.GatherDimensionNumbers gather_dimension_numbers = 33; |
| repeated int64 gather_slice_sizes = 34; |
| |
| // Compute Host. |
| string channel_name = 41; |
| int64 cost_estimate_ns = 42; |
| |
| // The id of this instruction. |
| int64 id = 35; |
| |
| repeated int64 operand_ids = 36; |
| repeated int64 control_predecessor_ids = 37; |
| repeated int64 called_computation_ids = 38; |
| |
| xla.OpSharding sharding = 40; |
| |
| // Backend configuration for the instruction. Has backend-specific meaning. |
| bytes backend_config = 43; |
| |
| // Cross replica op fields. |
| repeated ReplicaGroup replica_groups = 49; |
// Deprecated, but kept for backward compatibility; use channel_id instead.
// A non-positive all_reduce_id is equivalent to no all_reduce_id.
| int64 all_reduce_id = 45 [deprecated = true]; |
| |
// If true, ids in ReplicaGroup are interpreted as global device ids, i.e.,
// the linearized id `replica_id * partition_count + partition_id`.
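//
// For example (illustrative values), with partition_count = 4, the device
// with replica_id = 1 and partition_id = 2 has global device id
// 1 * 4 + 2 = 6.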
| bool use_global_device_ids = 71; |
| |
| // Whether this Send/Recv instruction transfers data to/from the host. Only |
| // present for Send and Recv instructions and their SendDone and RecvDone |
| // partners. |
| bool is_host_transfer = 47; |
| |
| // Whether this Sort instruction should be stable. |
| bool is_stable = 60; |
| |
| xla.ScatterDimensionNumbers scatter_dimension_numbers = 48; |
| |
| // Precision configuration for the instruction. Has backend-specific meaning. |
| xla.PrecisionConfig precision_config = 51; |
| |
| // Collective permute field. |
| repeated SourceTarget source_target_pairs = 52; |
| |
| // Sharding for kDomain instructions. |
| xla.OpSharding domain_entry_sharding = 54; |
| xla.OpSharding domain_exit_sharding = 55; |
| |
// For a custom call, this indicates that the layouts are constrained: if
// constrain_layout is true, then the 'shape' field must contain a layout,
// and 'operand_shapes_with_layout' must contain a shape with layout for
// each operand.
| bool constrain_layout = 56; |
| repeated xla.ShapeProto operand_shapes_with_layout = 57; |
| |
| // Options for TriangularSolve |
| xla.TriangularSolveOptions triangular_solve_options = 59; |
| |
| // Options for Cholesky |
| xla.CholeskyOptions cholesky_options = 62; |
| |
// Describes how parameters behave with regard to replicas.
| xla.ParameterReplication parameter_replication = 61; |
| |
| // If set, the given instruction is run in parallel on e.g. multiple CPU |
| // cores. The outermost dimension gets split up into |
| // outer_dimension_partitions[0] pieces, the next-outermost dim gets split |
| // into outer_dimension_partitions[1] pieces, etc. |
| // |
| // It's illegal to partition a dimension into more shards than there are |
| // elements in that dimension. |
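//
// For example (illustrative), a f32[8,16] result with
// outer_dimension_partitions = {2, 4} is computed as 2 * 4 = 8 parallel
// pieces, each producing a f32[4,4] slice of the output.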
| repeated int64 outer_dimension_partitions = 64; |
| |
// Whether this custom call has side effects. Only present for kCustomCall.
| bool custom_call_has_side_effect = 65; |
| |
| // A list of CustomCallOutputOperandAliasing pairs that specifies aliasing |
| // buffers between output and operands for kCustomCall. |
| repeated xla.CustomCallOutputOperandAliasing |
| custom_call_output_operand_aliasing = 74; |
| |
| // The delta value for kRngGetAndUpdateState. |
| int64 delta = 66; |
| |
| // Specifies if the gather/scatter indices are guaranteed to be sorted by the |
| // caller. |
| bool indices_are_sorted = 67; |
| |
| // Frontend attributes to pass to the XLA backend. |
| xla.FrontendAttributes frontend_attributes = 68; |
| |
| // Specifies if all elements updated are guaranteed to be unique by |
| // the caller. |
| bool unique_indices = 69; |
| |
| // RNG algorithm used by kRngBitGenerator. |
| xla.RandomAlgorithm rng_algorithm = 70; |
| |
| // The comparison type used for kCompare. |
| string comparison_type = 72; |
| |
// Specifies if this is a cross-program prefetch. Only used by kCopyStart.
| bool is_cross_program_prefetch = 73; |
| } |
| |
| // Serialization of HloComputation. |
| message HloComputationProto { |
| reserved 3; |
| reserved "root_name"; |
| |
| string name = 1; |
| |
| // The array of instructions is always in a valid dependency order, where |
| // operands appear before their users. |
| repeated HloInstructionProto instructions = 2; |
| |
// The program shape (with layout) of this computation.
xla.ProgramShapeProto program_shape = 4;
| |
| // The id of this computation. |
| int64 id = 5; |
| |
| // The id of the root of the computation. |
| int64 root_id = 6; |
| } |
| |
| // Serialization of an HLO schedule. An HLO schedule contains a total order of |
| // instructions for each non-fusion computation in the module. |
| message HloScheduleProto { |
| message InstructionSequence { |
| repeated int64 instruction_ids = 1; |
| } |
| |
| // Map from computation id to sequence. |
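//
// For example (an illustrative sketch in proto text format; the ids are
// hypothetical), a computation with id 2 scheduled as instructions
// 5, 6, 7 would be:
//
//   sequences {
//     key: 2
//     value { instruction_ids: [5, 6, 7] }
//   }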
| map<int64, InstructionSequence> sequences = 1; |
| } |
| |
| enum Kind { |
// Define an UNDEFINED_ALIAS equal to zero to get around the default-0 proto3
// behavior and missing has_*() APIs.
| UNDEFINED_ALIAS = 0; |
| // The buffers may or may not alias at runtime. |
| MAY_ALIAS = 1; |
| // The buffers must alias at runtime. |
| MUST_ALIAS = 2; |
| } |
| |
| message HloInputOutputAliasProto { |
// Each entry describes an aliased pair: an input (described by parameter
// number and a ShapeIndex of the parameter) and an output (described by a
// ShapeIndex of the root instruction). For example:
| // |
| // entry = { |
| // output_shape_index={1}, |
| // parameter_number=0, |
| // parameter_shape_index={1, 2}, |
| // } |
| // |
// This entry indicates that the first parameter's {1, 2} element is
// aliased with the {1} element of the root instruction.
| message AliasEntryProto { |
| // ShapeIndex of the root hlo. |
| repeated int64 output_shape_index = 1; |
| // Number of the parameter in entry computation. |
| int64 parameter_number = 2; |
| // ShapeIndex of the parameter instruction. |
| repeated int64 parameter_shape_index = 3; |
// The kind of alias to be set up.
| Kind kind = 4; |
| } |
| |
| repeated AliasEntryProto entries = 1; |
| } |
| |
| message DynamicParameterBindingProto { |
// A list of bindings which indicates that dimension `target_param_dim_num`
// in the subshape `target_param_index` of parameter `target_param_num`
// is a dynamic dimension whose real dynamic size is represented
// by `dynamic_param_index` in parameter `dynamic_param_num`.
| // |
| // As an example, imagine we have a program: |
| // |
| // ENTRY main { |
| // a = f32[] parameter(0) |
| // b = f32[10] parameter(1) |
| // ROOT root = (f32[], f32[10]) tuple(%a, %b) |
| // } |
| // |
// Let's say 'b' (param index 1) is a dynamic shape whose input has an
// upper bound of 10 and whose real size is determined at runtime, and 'a'
// represents the real size of b's first dimension.
//
// In this case, the fields are set in the following way:
// dynamic_param_num = 0
// dynamic_param_index = {}
// target_param_num = 1
// target_param_index = {}
// target_param_dim_num = 0
| message Binding { |
| int64 dynamic_param_num = 1; |
| repeated int64 dynamic_param_index = 2; |
| int64 target_param_num = 3; |
| repeated int64 target_param_index = 4; |
| int64 target_param_dim_num = 5; |
| } |
| |
| repeated Binding entries = 1; |
| } |
| |
| message CrossProgramPrefetch { |
// The parameter number of the cross-program-prefetched buffer.
int64 parameter = 1;
// The ShapeIndex within that parameter.
repeated int64 index = 2;
| } |
| |
| // Serialization of HloModule. |
| message HloModuleProto { |
| string name = 1; |
| string entry_computation_name = 2; |
| int64 entry_computation_id = 6; |
| |
| // The array of computations is always in a valid dependency order, where |
| // callees appear before their callers. |
| repeated HloComputationProto computations = 3; |
| |
| // The host program shape (with layout) of the entry computation. |
| xla.ProgramShapeProto host_program_shape = 4; |
| |
| // The id of this module. |
| int64 id = 5; |
| |
| // The schedule for this module. |
| HloScheduleProto schedule = 7; |
| |
| // Describes alias information between inputs and outputs. |
| HloInputOutputAliasProto input_output_alias = 8; |
| |
| DynamicParameterBindingProto dynamic_parameter_binding = 9; |
| |
| repeated CrossProgramPrefetch cross_program_prefetches = 10; |
| } |
| |
| // Serialization of LogicalBuffer. |
| message LogicalBufferProto { |
| // Location represents an instruction and its shape index, which uniquely |
| // identifies a point where a buffer is needed. |
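//
// For example (an illustrative sketch in proto text format; the names are
// hypothetical), element {1} of the tuple produced by instruction
// %tuple.5 in computation %entry would be:
//
//   computation_name: "entry"
//   instruction_name: "tuple.5"
//   shape_index: 1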
| message Location { |
| // NOTE: module_name isn't necessary, since all LogicalBuffers are |
| // associated with a single HloModule. |
| string computation_name = 1; |
| string instruction_name = 2; |
| repeated int64 shape_index = 3; |
| } |
| |
| int64 id = 1; |
| int64 size = 2; |
| |
| // The location where the buffer is defined. |
| Location defined_at = 3; |
| |
| int64 color = 4; |
| } |
| |
| // Serialization of BufferAllocation. |
| message BufferAllocationProto { |
| // Assigned represents a single LogicalBuffer that is assigned to this |
| // BufferAllocation. |
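//
// For example (an illustrative sketch in proto text format; the ids are
// hypothetical), two logical buffers packed back-to-back into a single
// 1024-byte allocation:
//
//   assigned { logical_buffer_id: 1 offset: 0 size: 512 }
//   assigned { logical_buffer_id: 2 offset: 512 size: 512 }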
| message Assigned { |
| int64 logical_buffer_id = 1; |
| int64 offset = 2; |
| int64 size = 3; |
| } |
| |
| int64 index = 1; |
| int64 size = 2; |
| bool is_thread_local = 3; |
| bool is_tuple = 11; |
| bool is_entry_computation_parameter = 5; |
| bool is_constant = 12; |
| int64 parameter_number = 6; |
| repeated int64 parameter_shape_index = 10; |
| bool maybe_live_out = 7; |
| int64 color = 8; |
| repeated Assigned assigned = 9; |
| } |
| |
| // A trace of a HeapSimulator run. |
| message HeapSimulatorTrace { |
| // The trace includes a list of events, where each event describes one action |
| // performed by the heap simulator. |
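//
// For example (an illustrative sketch in proto text format; the ids are
// hypothetical), a trace in which buffer 2 reuses the memory of canonical
// buffer 1 before both are freed:
//
//   events { kind: ALLOC buffer_id: 1 }
//   events { kind: SHARE_WITH buffer_id: 2 share_with_canonical_id: 1 }
//   events { kind: FREE buffer_id: 1 }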
| message Event { |
| enum Kind { |
| ALLOC = 0; // A memory region was allocated for the buffer. |
| FREE = 1; // A memory region was freed for the buffer. |
| |
| // A buffer was shared with another (canonical) buffer. This is similar to |
| // ALLOC, except that instead of allocating a new region of memory, the |
| // memory region of the canonical buffer is directly re-used. Multiple |
| // buffers may share with the same canonical buffer. The lifetime of the |
| // canonical buffer is extended to the union of all lifetimes. |
| SHARE_WITH = 2; |
| } |
| Kind kind = 1; |
| |
| // The id of the LogicalBuffer that the event applies to. |
| int64 buffer_id = 2; |
| |
// The HloInstruction the simulation was processing when this event
// occurred, identified by its computation and instruction name, e.g.,
// buffers defined by instruction A are allocated when processing A.
| string computation_name = 3; |
| string instruction_name = 4; |
| |
| // The id of the canonical LogicalBuffer that the buffer shares with. Only |
| // set for SHARE_WITH events. |
| int64 share_with_canonical_id = 5; |
| } |
| repeated Event events = 1; |
| bool whole_module_simulation = 2; |
| } |
| |
// An abstraction representing a set of HLO modules built to run
// concurrently across different devices.
| message HloModuleGroupProto { |
| string name = 1; |
| repeated HloModuleProto hlo_modules = 2; |
| } |
| |
| // Serialization of BufferAssignment. |
| message BufferAssignmentProto { |
// Alias represents a source LogicalBuffer and the buffer location that
// aliases it.
| message BufferAlias { |
| int64 source_buffer_id = 1; |
| LogicalBufferProto.Location location = 2; |
| } |
| |
| repeated LogicalBufferProto logical_buffers = 1; |
| repeated BufferAlias buffer_aliases = 2; |
| repeated BufferAllocationProto buffer_allocations = 3; |
| repeated HeapSimulatorTrace heap_simulator_traces = 4; |
| } |
| |
| // Grouping message that contains all of the information above. |
| message HloProto { |
| reserved 2; |
| reserved "hlo_ordering"; |
| |
| HloModuleProto hlo_module = 1; |
| BufferAssignmentProto buffer_assignment = 3; |
| } |
| |
| // Encapsulates HloProto together with the arguments, result, and |
| // execution_platform. This message is used for purposes such as |
| // analysis/replay/file-storage. |
| message HloSnapshot { |
| // The hlo graph. |
| HloProto hlo = 1; |
| |
| // The arguments passed to the graph. |
| repeated LiteralProto arguments = 2; |
| |
| // The result of the graph. |
| LiteralProto result = 3; |
| |
| // The name of the platform used to run the graph. |
| string execution_platform = 4; |
| } |