| /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| // All HloInstruction subclasses are put in this file. |
| |
| #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ |
| #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ |
| |
| #include <functional> |
| #include <memory> |
| #include <string> |
| #include <utility> |
| |
| #include "absl/container/inlined_vector.h" |
| #include "absl/strings/string_view.h" |
| #include "absl/types/span.h" |
| #include "tensorflow/compiler/xla/service/hlo_instruction.h" |
| #include "tensorflow/compiler/xla/shape.h" |
| #include "tensorflow/compiler/xla/xla_data.pb.h" |
| |
| namespace xla { |
| |
| // Base class for instructions with a dimensions vector. |
| class HloDimensionsInstruction : public HloInstruction { |
| public: |
| HloDimensionsInstruction(HloOpcode opcode, const Shape& shape, |
| absl::Span<const int64_t> dimensions) |
| : HloInstruction(opcode, shape), |
| dimensions_(dimensions.begin(), dimensions.end()) {} |
| |
| absl::Span<const int64_t> dimensions() const override { return dimensions_; } |
| |
| std::vector<int64_t>* mutable_dimensions() override { return &dimensions_; } |
| |
| HloInstructionProto ToProto() const override; |
| |
| protected: |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| |
| bool IdenticalSlowPath( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| |
| std::vector<int64_t> dimensions_; |
| }; |
| |
| class HloBatchNormInstruction : public HloInstruction { |
| public: |
| // Returns feature_index field associated with the instruction. The index |
| // represents the index of the feature dimension. |
| int64_t feature_index() const { return feature_index_; } |
| |
| // Returns a epsilon value associated with the instruction. The is a small |
| // number added to the variance to avoid divide-by-zero error. |
| float epsilon() const { return epsilon_; } |
| |
| // Returns a serialized representation of this instruction. |
| HloInstructionProto ToProto() const override; |
| |
| protected: |
| explicit HloBatchNormInstruction(HloOpcode opcode, const Shape& shape, |
| HloInstruction* operand, |
| HloInstruction* scale, float epsilon, |
| int64_t feature_index); |
| |
| private: |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| bool IdenticalSlowPath( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| // A small float number added to the variance to avoid divide-by-zero error. |
| float epsilon_ = 0.0f; |
| |
| // An integer value representing the index of the feature dimension. |
| int64_t feature_index_ = -1; |
| }; |
| |
| class HloBatchNormTrainingInstruction : public HloBatchNormInstruction { |
| public: |
| explicit HloBatchNormTrainingInstruction( |
| const Shape& shape, HloInstruction* operand, HloInstruction* scale, |
| HloInstruction* offset, float epsilon, int64_t feature_index); |
| |
| private: |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| }; |
| |
| class HloBatchNormInferenceInstruction : public HloBatchNormInstruction { |
| public: |
| explicit HloBatchNormInferenceInstruction( |
| const Shape& shape, HloInstruction* operand, HloInstruction* scale, |
| HloInstruction* offset, HloInstruction* mean, HloInstruction* variance, |
| float epsilon, int64_t feature_index); |
| |
| private: |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| }; |
| |
| class HloBatchNormGradInstruction : public HloBatchNormInstruction { |
| public: |
| explicit HloBatchNormGradInstruction( |
| const Shape& shape, HloInstruction* operand, HloInstruction* scale, |
| HloInstruction* mean, HloInstruction* variance, |
| HloInstruction* grad_output, float epsilon, int64_t feature_index); |
| |
| private: |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| }; |
| |
| class HloFftInstruction : public HloInstruction { |
| public: |
| explicit HloFftInstruction(const Shape& shape, HloInstruction* operand, |
| FftType fft_type, |
| absl::Span<const int64_t> fft_length); |
| FftType fft_type() const { return fft_type_; } |
| |
| const std::vector<int64_t>& fft_length() const { return fft_length_; } |
| |
| // Returns a serialized representation of this instruction. |
| HloInstructionProto ToProto() const override; |
| |
| private: |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| bool IdenticalSlowPath( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| |
| // Describes FFT type for an FFT instruction. |
| FftType fft_type_ = FftType::FFT; |
| |
| // Indicates the FFT length for an FFT instruction. |
| std::vector<int64_t> fft_length_; |
| }; |
| |
| class HloAsyncInstruction : public HloInstruction { |
| public: |
| HloAsyncInstruction( |
| HloOpcode opcode, const Shape& shape, |
| absl::Span<HloInstruction* const> operands, |
| HloComputation* async_computation, |
| std::optional<int64_t> async_group_id = std::nullopt, |
| std::optional<std::string> async_thread_name = std::nullopt); |
| HloAsyncInstruction( |
| HloOpcode opcode, const Shape& shape, HloInstruction* operand, |
| HloComputation* async_computation, |
| std::optional<int64_t> async_group_id = std::nullopt, |
| std::optional<std::string> async_thread_name = std::nullopt); |
| |
| ~HloAsyncInstruction() override; |
| // When an async instruction is being destructed, remove it from the vector of |
| // pointers of its called computation, to avoid referencing freed memory. |
| void ClearAsyncComputationInstruction(); |
| |
| HloInstruction* async_wrapped_instruction() const; |
| HloOpcode async_wrapped_opcode() const; |
| |
| // Async group id is a unique id given to a group of async operations that |
| // consist of one async start, one async done, and zero or more async update |
| // operations. The async group participates in a single async operation. The |
| // async operation canonicalizer pass assigns async group ids. |
| std::optional<int64_t> async_group_id() const { return async_group_id_; } |
| |
| // Async thread name is a unique thread name for one or more async groups. |
| // Typically one HLO module contains a main thread as well as one or more |
| // parallel threads. Empty async_thread_name is equivalent to main thread. |
| std::optional<absl::string_view> async_thread_name() const { |
| return async_thread_name_; |
| } |
| void set_async_group_id(std::optional<int64_t> async_group_id); |
| void set_async_thread_name( |
| const std::optional<std::string>& async_thread_name); |
| HloInstructionProto ToProto() const override; |
| |
| private: |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| bool IdenticalSlowPath( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| std::optional<int64_t> async_group_id_; |
| std::optional<std::string> async_thread_name_; |
| }; |
| |
| class HloCopyStartInstruction : public HloInstruction { |
| public: |
| explicit HloCopyStartInstruction(const Shape& shape, HloInstruction* operand, |
| bool is_cross_program_prefetch); |
| |
| bool is_cross_program_prefetch() const { return is_cross_program_prefetch_; } |
| HloInstructionProto ToProto() const override; |
| |
| private: |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| bool IdenticalSlowPath( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| |
| bool is_cross_program_prefetch_; |
| }; |
| |
| class HloCompareInstruction : public HloInstruction { |
| public: |
| explicit HloCompareInstruction(const Shape& shape, HloInstruction* lhs, |
| HloInstruction* rhs, |
| ComparisonDirection direction, |
| std::optional<Comparison::Type> type); |
| ComparisonDirection direction() const { return compare_.GetDirection(); } |
| ComparisonOrder order() const { return compare_.GetOrder(); } |
| Comparison::Type type() const { return compare_.GetType(); } |
| HloInstructionProto ToProto() const override; |
| |
| private: |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| bool IdenticalSlowPath( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| |
| Comparison compare_; |
| }; |
| |
| class HloTriangularSolveInstruction : public HloInstruction { |
| public: |
| explicit HloTriangularSolveInstruction(const Shape& shape, HloInstruction* a, |
| HloInstruction* b, |
| const TriangularSolveOptions& options); |
| const TriangularSolveOptions& triangular_solve_options() const { |
| return triangular_solve_options_; |
| } |
| |
| // Returns a serialized representation of this instruction. |
| HloInstructionProto ToProto() const override; |
| |
| private: |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| bool IdenticalSlowPath( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| |
| TriangularSolveOptions triangular_solve_options_; |
| }; |
| |
| class HloCholeskyInstruction : public HloInstruction { |
| public: |
| explicit HloCholeskyInstruction(const Shape& shape, HloInstruction* a, |
| const CholeskyOptions& options); |
| const CholeskyOptions& cholesky_options() const { return cholesky_options_; } |
| |
| // Returns a serialized representation of this instruction. |
| HloInstructionProto ToProto() const override; |
| |
| private: |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| bool IdenticalSlowPath( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| |
| CholeskyOptions cholesky_options_; |
| }; |
| |
| // Class that represents instructions that synchronize and transfer data between |
| // partitioned devices. Send/Recv and collective instructions (AllReduce, |
| // AllToAll, CollectivePermute) belong to this instruction type. A group of |
| // instructions (of the same opcode) with the same channel_id communicate during |
| // execution. |
| class HloChannelInstruction : public HloInstruction { |
| public: |
| // Returns the channel id associated with the instruction. The id is |
| // shared between each Send/Recv pair or a group of collective instructions |
| // and is globally unique to identify each channel. |
| std::optional<int64_t> channel_id() const { return channel_id_; } |
| void set_channel_id(const std::optional<int64_t>& channel_id); |
| |
| // Whether this instruction is identical to `other` except for the values of |
| // channel IDs, as long as both have channel IDs or neither has a channel ID. |
| virtual bool IdenticalSlowPathIgnoringChannelIdValues( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const { |
| return channel_id_.has_value() == other.channel_id().has_value(); |
| } |
| |
| protected: |
| explicit HloChannelInstruction(HloOpcode opcode, const Shape& shape, |
| const std::optional<int64_t>& channel_id); |
| |
| HloInstructionProto ToProto() const override; |
| |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| |
| // Do not override IdenticalSlowPath(). Override |
| // IdenticalSlowPathIgnoringChannelIdValues() instead. |
| bool IdenticalSlowPath( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const final; |
| |
| std::optional<int64_t> channel_id_; |
| }; |
| |
| class HloSendRecvInstruction : public HloChannelInstruction { |
| public: |
| // Returns whether this send/recv instruction sends data to/from the host. |
| bool is_host_transfer() const { return is_host_transfer_; } |
| |
| // Returns a serialized representation of this instruction. |
| HloInstructionProto ToProto() const override; |
| |
| protected: |
| explicit HloSendRecvInstruction(HloOpcode opcode, const Shape& shape, |
| int64_t channel_id, bool is_host_transfer); |
| |
| private: |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| bool IdenticalSlowPathIgnoringChannelIdValues( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| // Whether this send/recv instruction sends data to/from the host. |
| bool is_host_transfer_; |
| }; |
| |
| class HloSendInstruction : public HloSendRecvInstruction { |
| public: |
| explicit HloSendInstruction(HloInstruction* operand, HloInstruction* token, |
| int64_t channel_id, bool is_host_transfer); |
| |
| private: |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| }; |
| |
| class HloSendDoneInstruction : public HloSendRecvInstruction { |
| public: |
| explicit HloSendDoneInstruction(HloSendInstruction* operand, |
| bool is_host_transfer); |
| |
| private: |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| }; |
| |
| class HloRecvInstruction : public HloSendRecvInstruction { |
| public: |
| explicit HloRecvInstruction(const Shape& shape, HloInstruction* token, |
| int64_t channel_id, bool is_host_transfer); |
| |
| private: |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| }; |
| |
| class HloRecvDoneInstruction : public HloSendRecvInstruction { |
| public: |
| explicit HloRecvDoneInstruction(HloRecvInstruction* operand, |
| bool is_host_transfer); |
| |
| private: |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| }; |
| |
| class HloCollectiveInstruction : public HloChannelInstruction { |
| public: |
| const std::vector<ReplicaGroup>& replica_groups() const { |
| return replica_groups_; |
| } |
| |
| // Returns true if the layout of the AllReduce is enforced by XLA client (as |
| // the layout set in the shape). The only reason for the client to set the |
| // layout is to separately compile computations that communicate with |
| // AllReduce. Since this field is only set `true` by the client, the compiler |
| // only needs to propagate existing values (e.g., Clone, X64Rewriter) or set |
| // `false` for all other cases. |
| // |
| // When this is `true`, there may be communication endpoints outside the |
| // current compilation unit, so the compiler considers this AllReduce as |
| // side-effecting to disable compiler transformations. The compiler is free to |
| // transform unconstrained AllReduces differently across compilation units. |
| // It is an error for an HloModule to have a mix of constrained and |
| // unconstrained AllReduce instructions (checked by HloVerifier). |
| bool constrain_layout() const { return constrain_layout_; } |
| |
| protected: |
| explicit HloCollectiveInstruction( |
| HloOpcode opcode, const Shape& shape, |
| absl::Span<HloInstruction* const> operands, |
| absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout, |
| const std::optional<int64_t>& channel_id); |
| |
| HloInstructionProto ToProto() const override; |
| |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| bool IdenticalSlowPathIgnoringChannelIdValues( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| |
| std::vector<ReplicaGroup> replica_groups_; |
| bool constrain_layout_; |
| }; |
| |
| class HloAllGatherInstruction : public HloCollectiveInstruction { |
| public: |
| explicit HloAllGatherInstruction( |
| HloOpcode opcode, const Shape& shape, |
| absl::Span<HloInstruction* const> operands, int64_t all_gather_dimension, |
| absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout, |
| const std::optional<int64_t>& channel_id, bool use_global_device_ids); |
| // Same as HloAllReduceInstruction::use_global_device_ids. |
| bool use_global_device_ids() const { return use_global_device_ids_; } |
| |
| // The dimension on which data from different participants are concatenated. |
| int64_t all_gather_dimension() const { return all_gather_dimension_; } |
| absl::Span<const int64_t> dimensions() const override { |
| return absl::MakeConstSpan(&all_gather_dimension_, 1); |
| } |
| |
| void set_all_gather_dimension(int64_t dim) { all_gather_dimension_ = dim; } |
| |
| protected: |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| HloInstructionProto ToProto() const override; |
| |
| private: |
| bool IdenticalSlowPathIgnoringChannelIdValues( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| |
| int64_t all_gather_dimension_; |
| bool use_global_device_ids_; |
| }; |
| |
| // Base class for all-reduce and all-reduce scatter instructions. |
| class HloAllReduceInstructionBase : public HloCollectiveInstruction { |
| public: |
| explicit HloAllReduceInstructionBase( |
| HloOpcode opcode, const Shape& shape, |
| absl::Span<HloInstruction* const> operands, |
| HloComputation* reduce_computation, |
| absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout, |
| const std::optional<int64_t>& channel_id, bool use_global_device_ids); |
| |
| // Returns true if the ids in the ReplicaGroup config represent a global id of |
| // (replica_id * partition_count + partition_id) instead of a replica id. |
| // This enables more flexible grouping of devices if this all-reduce is both |
| // cross-partition and cross-replica. |
| // |
| // For example with 2 replicas and 4 partitions, |
| // replica_groups={{0,1,4,5},{2,3,6,7}}, use_global_device_ids=true means that |
| // group[0] = (0,0), (0,1), (1,0), (1,1) |
| // group[1] = (0,2), (0,3), (1,2), (1,3) |
| // where each pair is (replica_id, partition_id). |
| bool use_global_device_ids() const { return use_global_device_ids_; } |
| void set_use_global_device_ids(bool value) { use_global_device_ids_ = value; } |
| |
| protected: |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| HloInstructionProto ToProto() const override; |
| |
| bool IdenticalSlowPathIgnoringChannelIdValues( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| |
| private: |
| bool use_global_device_ids_; |
| }; |
| |
| class HloAllReduceInstruction : public HloAllReduceInstructionBase { |
| public: |
| using HloAllReduceInstructionBase::HloAllReduceInstructionBase; |
| |
| // Returns true if the AllReduce does no communication, so it's equivalent |
| // to a mem copy. |
| bool IsNoop() const; |
| |
| private: |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| }; |
| |
| class HloReduceScatterInstruction : public HloAllReduceInstructionBase { |
| public: |
| explicit HloReduceScatterInstruction( |
| const Shape& shape, absl::Span<HloInstruction* const> operands, |
| HloComputation* reduce_computation, |
| absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout, |
| const std::optional<int64_t>& channel_id, bool use_global_device_ids, |
| int64_t scatter_dimension); |
| |
| // The dimension on which reduced data is scattered to different participants. |
| int64_t scatter_dimension() const { return scatter_dimension_; } |
| absl::Span<const int64_t> dimensions() const override { |
| return absl::MakeConstSpan(&scatter_dimension_, 1); |
| } |
| |
| protected: |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| HloInstructionProto ToProto() const override; |
| |
| private: |
| bool IdenticalSlowPathIgnoringChannelIdValues( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| |
| int64_t scatter_dimension_; |
| }; |
| |
| class HloAllToAllInstruction : public HloCollectiveInstruction { |
| public: |
| explicit HloAllToAllInstruction( |
| const Shape& shape, absl::Span<HloInstruction* const> operands, |
| absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout, |
| const std::optional<int64_t>& channel_id, |
| const std::optional<int64_t>& split_dimension); |
| |
| // AllToAll can optionally take a split dimension, which means that this |
| // AllToAll takes a single (flattened) array operand and produces an array |
| // output (instead of taking a list of operands and producing a tuple). |
| // |
| // split_dimension specifies which dimension in the operand is split across |
| // devices in each replica_group, and also means the concatenated dimension |
| // on the output (i.e., input and the output shapes are the same). |
| std::optional<int64_t> split_dimension() const { return split_dimension_; } |
| void set_split_dimension(int64_t dim) { split_dimension_ = dim; } |
| |
| protected: |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| HloInstructionProto ToProto() const override; |
| |
| private: |
| bool IdenticalSlowPathIgnoringChannelIdValues( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| |
| std::optional<int64_t> split_dimension_; |
| }; |
| |
| class HloCollectivePermuteInstruction : public HloChannelInstruction { |
| public: |
| explicit HloCollectivePermuteInstruction( |
| HloOpcode opcode, const Shape& shape, HloInstruction* operand, |
| const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs, |
| const std::optional<int64_t>& channel_id); |
| |
| explicit HloCollectivePermuteInstruction( |
| HloOpcode opcode, const Shape& shape, HloInstruction* input, |
| HloInstruction* output, HloInstruction* input_start_indices, |
| HloInstruction* output_start_indices, |
| absl::Span<const std::pair<int64_t, int64_t>> source_target_pairs, |
| absl::Span<const std::vector<int64_t>> slice_sizes, |
| const std::optional<int64_t>& channel_id); |
| |
| const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs() const { |
| return source_target_pairs_; |
| } |
| |
| const std::vector<std::vector<int64_t>>& dynamic_slice_sizes_list() const { |
| return slice_sizes_; |
| } |
| |
| // Returns a serialized representation of this instruction. |
| HloInstructionProto ToProto() const override; |
| |
| private: |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| bool IdenticalSlowPathIgnoringChannelIdValues( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| |
| const std::vector<std::pair<int64_t, int64_t>> source_target_pairs_; |
| const std::vector<std::vector<int64_t>> slice_sizes_; |
| }; |
| |
| class HloReverseInstruction : public HloDimensionsInstruction { |
| public: |
| explicit HloReverseInstruction(const Shape& shape, HloInstruction* operand, |
| absl::Span<const int64_t> dimensions); |
| |
| private: |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| }; |
| |
| class HloConcatenateInstruction : public HloDimensionsInstruction { |
| public: |
| explicit HloConcatenateInstruction(const Shape& shape, |
| absl::Span<HloInstruction* const> operands, |
| int64_t dimension); |
| // Accessor for the dimension in which a concatenate HLO should occur. |
| int64_t concatenate_dimension() const override { |
| return HloInstruction::dimensions(0); |
| } |
| |
| private: |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| }; |
| |
| class HloReduceInstruction : public HloDimensionsInstruction { |
| public: |
| explicit HloReduceInstruction(const Shape& shape, |
| absl::Span<HloInstruction* const> args, |
| absl::Span<const int64_t> dimensions_to_reduce, |
| HloComputation* reduce_computation); |
| |
| // Returns the number of input arrays (and, consequentially, the number of |
| // init values) this reduce has. |
| int64_t input_count() const { return operand_count() / 2; } |
| |
| // Returns the input tensors to be reduced. |
| absl::Span<HloInstruction* const> inputs() const { |
| return absl::MakeSpan(operands()).subspan(0, input_count()); |
| } |
| |
| // Returns the init values of the reduction. |
| absl::Span<HloInstruction* const> init_values() const { |
| return absl::MakeSpan(operands()).subspan(input_count(), operand_count()); |
| } |
| |
| private: |
| bool IdenticalSlowPath( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| }; |
| |
| class HloSortInstruction : public HloDimensionsInstruction { |
| public: |
| explicit HloSortInstruction(const Shape& shape, int64_t dimension, |
| absl::Span<HloInstruction* const> operands, |
| HloComputation* compare, bool is_stable); |
| // Returns the sort dimension for this instruction |
| int64_t sort_dimension() const { return HloInstruction::dimensions(0); } |
| // Returns a serialized representation of this instruction. |
| HloInstructionProto ToProto() const override; |
| // Returns the key operand to this instruction. |
| const HloInstruction* keys() const { return operand(0); } |
| HloInstruction* mutable_keys() { return mutable_operand(0); } |
| // Returns the number of value operands. |
| int64_t values_count() const { return operand_count() - 1; } |
| bool is_stable() const { return is_stable_; } |
| |
| private: |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| bool IdenticalSlowPath( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| |
| bool is_stable_; |
| }; |
| |
| class HloTransposeInstruction : public HloDimensionsInstruction { |
| public: |
| explicit HloTransposeInstruction(const Shape& shape, HloInstruction* operand, |
| absl::Span<const int64_t> dimensions); |
| // Returns whether this instruction does a rank-2 transposition. |
| bool IsRank2Transpose() const; |
| |
| private: |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| }; |
| |
| class HloBroadcastInstruction : public HloDimensionsInstruction { |
| public: |
| explicit HloBroadcastInstruction( |
| const Shape& shape, HloInstruction* operand, |
| absl::Span<const int64_t> broadcast_dimension); |
| |
| private: |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| }; |
| |
| class HloDynamicReshapeInstruction : public HloInstruction { |
| public: |
| explicit HloDynamicReshapeInstruction( |
| const Shape& shape, HloInstruction* data_operand, |
| absl::Span<HloInstruction* const> dim_sizes); |
| |
| // Returns the input dim sizes dimensions, which is operands[1:] |
| absl::Span<HloInstruction* const> dim_sizes() const { |
| return absl::MakeSpan(operands()).subspan(1, operand_count()); |
| } |
| |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| |
| // Returns the input dim size dimension, which is operands[1+i] |
| HloInstruction* dim_sizes(int64_t i) const { return operands()[i + 1]; } |
| }; |
| |
| class HloReshapeInstruction : public HloInstruction { |
| public: |
| explicit HloReshapeInstruction(const Shape& shape, HloInstruction* operand, |
| int64_t inferred_dimension); |
| int64_t inferred_dimension() const { return inferred_dimension_; } |
| HloInstructionProto ToProto() const override; |
| |
| private: |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| bool IdenticalSlowPath( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| int64_t inferred_dimension_; |
| }; |
| |
| class HloMapInstruction : public HloInstruction { |
| public: |
| explicit HloMapInstruction(const Shape& shape, |
| absl::Span<HloInstruction* const> operands, |
| HloComputation* map_computation); |
| // Returns the dimension sizes or numbers associated with this instruction. |
| absl::Span<const int64_t> dimensions() const override { return dimensions_; } |
| |
| std::vector<int64_t>* mutable_dimensions() override { return &dimensions_; } |
| // Returns a serialized representation of this instruction. |
| HloInstructionProto ToProto() const override; |
| |
| private: |
| bool IsElementwiseImpl( |
| const std::optional<int64_t>& operand_idx) const override; |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| bool IdenticalSlowPath( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| |
| std::vector<int64_t> dimensions_; |
| }; |
| |
| class HloSliceInstruction : public HloInstruction { |
| public: |
| explicit HloSliceInstruction(const Shape& shape, HloInstruction* operand, |
| absl::Span<const int64_t> start_indices, |
| absl::Span<const int64_t> limit_indices, |
| absl::Span<const int64_t> strides); |
| |
| HloInstructionProto ToProto() const override; |
| |
| // Returns the start index in the given dimension for a slice node. |
| int64_t slice_starts(int64_t dimension) const { |
| return slice_starts_[dimension]; |
| } |
| const std::vector<int64_t>& slice_starts() const { return slice_starts_; } |
| std::vector<int64_t>* mutable_slice_starts() { return &slice_starts_; } |
| |
| // Returns the (exclusive) limit index in the given dimension for a slice |
| // node. |
| int64_t slice_limits(int64_t dimension) const { |
| return slice_limits_[dimension]; |
| } |
| const std::vector<int64_t>& slice_limits() const { return slice_limits_; } |
| std::vector<int64_t>* mutable_slice_limits() { return &slice_limits_; } |
| |
| // Returns the stride in the given dimension for a slice node. |
| int64_t slice_strides(int64_t dimension) const { |
| return slice_strides_[dimension]; |
| } |
| const std::vector<int64_t>& slice_strides() const { return slice_strides_; } |
| std::vector<int64_t>* mutable_slice_strides() { return &slice_strides_; } |
| |
| private: |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| bool IdenticalSlowPath( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| |
| // Describes the [begin, end) index range for a slice. |
| std::vector<int64_t> slice_starts_; |
| std::vector<int64_t> slice_limits_; |
| std::vector<int64_t> slice_strides_; |
| }; |
| |
| class HloConstantInstruction : public HloInstruction { |
| public: |
| explicit HloConstantInstruction(Literal literal); |
| explicit HloConstantInstruction(Literal literal, const Shape& shape); |
| // Used when the literal is too large and dropped. |
| explicit HloConstantInstruction(const Shape& shape); |
| // Returns the literal associated with this instruction. |
| const Literal& literal() const { return *literal_; } |
| // Returns the (mutable) literal associated with this instruction. |
| Literal* mutable_literal() { return &literal_.value(); } |
| // Returns whether there is literal associated with this instruction. |
| bool HasLiteral() const { return literal_.has_value(); } |
| // Returns a serialized representation of this instruction. |
| HloInstructionProto ToProto() const override; |
| |
| // Change the layout for an Constant Hlo instruction to match new_layout. For |
| // tuple shaped constants shape_index is the path to the internal array |
| // subshape whose layout needs to be changed. |
| void RelayoutConstant(const Layout& new_layout, |
| const ShapeIndex& shape_index = {}); |
| |
| private: |
| bool IsElementwiseImpl( |
| const std::optional<int64_t>& operand_idx) const override; |
| bool IdenticalSlowPath( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| std::string OperandsToStringWithCanonicalNameMap( |
| const HloPrintOptions& options, |
| CanonicalNameMap* canonical_name_map) const override; |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| std::optional<Literal> literal_; |
| }; |
| |
| // Abstract class that represents an HLO instruction that "calls" a computation. |
| // Fusion and Call HLOs inherit from this class. |
| class HloCallableInstruction : public HloInstruction { |
| public: |
| HloCallableInstruction(HloOpcode opcode, const Shape& shape); |
| |
| HloCallableInstruction(HloOpcode opcode, const Shape& shape, |
| absl::Span<HloInstruction* const> operands, |
| HloComputation* called_computation); |
| |
| ~HloCallableInstruction() override; |
| |
| // Adds a new operand to the callable instruction. |
| HloInstruction* AddCallOperand(HloInstruction* new_operand); |
| |
| // Appends (fuses) the given instruction into this callable instruction. |
| // instruction_to_append is cloned and the clone is placed in the callable |
| // instruction. The users of instruction_to_append will be redirected to this |
| // callable instruction. instruction_to_append is unchanged otherwise. When |
| // add_output is true, a clone of the instruction_to_append will be added as |
| // additional output resulting in a multi-output callable instruction. |
| HloInstruction* AppendInstructionIntoCalledComputation( |
| HloInstruction* instruction_to_append, bool add_output = false); |
| // Clones the given instruction_to_append and inserts the clone into this |
| // callable instruction. If add_output is true, a clone of |
| // instruction_to_append will be in the output of the this callable |
| // instruction (part of the tuple of the callable root). |
| HloInstruction* CloneAndAppendInstructionIntoCalledComputation( |
| HloInstruction* instruction_to_append, bool add_output = false); |
| |
| HloComputation* called_computation() const; |
| |
| HloInstruction* called_computation_root() const; |
| |
| // Recursively sets all nested called computation to have thread name as |
| // `thread_name`. if `skip_async_thread_name_overwrite` is true, skip |
| // overwrite async instruction and its comptuations thread name overwriting. |
| void RecursivelySetComputationsThreadName( |
| std::optional<std::string> thread_name, |
| bool skip_async_thread_name_overwrite); |
| |
| protected: |
| // Returns the default called computation name. |
| virtual std::string default_called_computation_name() const = 0; |
| }; |
| |
| class HloFusionInstruction : public HloCallableInstruction { |
| public: |
| explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind, |
| HloInstruction* fused_root); |
| |
| explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind, |
| absl::Span<HloInstruction* const> operands, |
| HloComputation* fusion_computation); |
| |
| ~HloFusionInstruction() override; |
| |
| void ClearCalledComputations() override; |
| |
| // When a fusion instruction is being destructed, clear the back pointer of |
| // its fusion computation, to avoid referencing freed memory. |
| void ClearFusionComputationInstruction(); |
| |
| std::string ToCategory() const override; |
| // Returns a serialized representation of this instruction. |
| HloInstructionProto ToProto() const override; |
| |
| // Adds a new operand the fusion instruction. |
| HloInstruction* AddFusionOperand(HloInstruction* new_operand); |
| |
| // Merges the fused instructions from 'instruction_to_merge' into the |
| // fused instruction set of 'this', updating operands as necessary. |
| // |
| // Precondition: 'instruction_to_merge' must be an operand of 'this'. |
| void MergeFusionInstruction(HloFusionInstruction* instruction_to_merge); |
| |
| // Merges the fused instructions from instruction_to_merge into the fused |
| // instruction set of 'this' and generates multioutput fusion instructions. |
| // All the users of instruction_to_merge will be redirected to 'this' |
| // instruction. instruction_to_merge will be removed from its parent |
| // computation. |
| void MergeFusionInstructionIntoMultiOutput( |
| HloFusionInstruction* instruction_to_merge); |
| |
| // Fuses the given instruction in this fusion instruction. instruction_to_fuse |
| // is cloned and the clone is placed in the fusion |
| // instruction. instruction_to_fuse is unchanged. Instruction is cloned rather |
| // than moved to cleanly handle the case where the instruction has a use |
| // outside the fusion instruction. Moving such an instruction into a fusion |
| // instruction would violate the single-result invariant of HLO instructions |
| // and significantly complicate code generation. |
| HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse) { |
| CHECK(instruction_to_fuse->IsFusible()) << instruction_to_fuse->ToString(); |
| return AppendInstructionIntoCalledComputation(instruction_to_fuse); |
| } |
| |
| // Fuses the given instruction in this fusion instruction and generates a |
| // multioutput fusion instruction. A clone of the instruction_to_fuse will |
| // be part of the output of fusion instructions. The users of |
| // instruction_to_fuse will be redirected to this fusion instructions. |
| // instruction_to_fuse is unchanged otherwise. |
| HloInstruction* FuseInstructionIntoMultiOutput( |
| HloInstruction* instruction_to_fuse) { |
| return AppendInstructionIntoCalledComputation(instruction_to_fuse, |
| /*add_output=*/true); |
| } |
| |
| // Returns the computation for this fused instruction. |
| HloComputation* fused_instructions_computation() const; |
| |
| // Returns the root instruction of the fused expression contained within this |
| // fusion instruction. |
| HloInstruction* fused_expression_root() const; |
| |
| // Returns the list of fused instructions inside this fusion instruction. The |
| // returned type is a range of HloInstruction*s. |
| const tensorflow::gtl::iterator_range<UnwrappingIterator< |
| std::list<std::unique_ptr<HloInstruction>>::const_iterator>> |
| fused_instructions() const; |
| |
| const tensorflow::gtl::iterator_range< |
| UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>> |
| fused_instructions(); |
| |
| // Gets the number of instructions inside this fusion instruction. |
| int64_t fused_instruction_count() const; |
| |
| // Returns the fused parameter instruction in this fusion instruction |
| // corresponding to the given parameter number. |
| HloInstruction* fused_parameter(int64_t parameter_number) const; |
| |
| // Returns the vector of fused parameters inside this fusion instruction. |
| const std::vector<HloInstruction*>& fused_parameters() const; |
| |
| // Returns true if this instruction is a fusion instruction that generates |
| // multiple outputs. |
| const bool IsMultiOutputFusion() const { |
| return fused_expression_root()->opcode() == HloOpcode::kTuple; |
| } |
| |
| FusionKind fusion_kind() const { return fusion_kind_; } |
| |
| void set_fusion_kind(FusionKind kind) { fusion_kind_ = kind; } |
| |
| // If multiple operands are the same instruction, keeps only one of them. |
| Status DeduplicateFusionOperands(); |
| |
| protected: |
| std::string default_called_computation_name() const override { |
| return "fused_computation"; |
| } |
| |
| private: |
| bool IsElementwiseImpl( |
| const std::optional<int64_t>& operand_idx) const override; |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| bool IdenticalSlowPath( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| |
| // The type of the fusion. Used by kFusion only. |
| FusionKind fusion_kind_; |
| }; |
| |
| class HloCallInstruction : public HloCallableInstruction { |
| public: |
| HloCallInstruction(const Shape& shape, |
| HloInstruction* called_computation_root); |
| |
| HloCallInstruction(const Shape& shape, |
| absl::Span<HloInstruction* const> operands, |
| HloComputation* called_computation); |
| |
| protected: |
| std::string default_called_computation_name() const override { |
| return "called_computation"; |
| } |
| }; |
| |
| class HloRngInstruction : public HloInstruction { |
| public: |
| explicit HloRngInstruction(const Shape& shape, |
| RandomDistribution distribution, |
| absl::Span<HloInstruction* const> parameters); |
| // Returns the random distribution for this rng node. |
| RandomDistribution random_distribution() const { return distribution_; } |
| // Returns a serialized representation of this instruction. |
| HloInstructionProto ToProto() const override; |
| |
| private: |
| bool IsElementwiseImpl( |
| const std::optional<int64_t>& operand_idx) const override; |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| bool IdenticalSlowPath( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| |
| // The distribution requested for random number generation. |
| RandomDistribution distribution_; |
| }; |
| |
| class HloParameterInstruction : public HloInstruction { |
| public: |
| explicit HloParameterInstruction(int64_t parameter_number, const Shape& shape, |
| const std::string& name); |
| int64_t parameter_number() const { return parameter_number_; } |
| |
| // Sets and gets the whether all replicas will receive the same parameter data |
| // for each leaf buffer in data parallelism. |
| void set_parameter_replicated_at_leaf_buffers( |
| absl::Span<const bool> parameter_replicated_at_leaf_buffers) { |
| CHECK_EQ(ShapeUtil::GetLeafCount(shape()), |
| parameter_replicated_at_leaf_buffers.size()); |
| parameter_replicated_at_leaf_buffers_.emplace( |
| parameter_replicated_at_leaf_buffers.begin(), |
| parameter_replicated_at_leaf_buffers.end()); |
| } |
| void set_parameter_replicated_at_leaf_buffers( |
| const std::vector<bool>& parameter_replicated_at_leaf_buffers) { |
| CHECK_EQ(ShapeUtil::GetLeafCount(shape()), |
| parameter_replicated_at_leaf_buffers.size()); |
| parameter_replicated_at_leaf_buffers_ = |
| parameter_replicated_at_leaf_buffers; |
| } |
| const std::optional<std::vector<bool>>& parameter_replicated_at_leaf_buffers() |
| const { |
| return parameter_replicated_at_leaf_buffers_; |
| } |
| |
| // Returns a serialized representation of this instruction. |
| HloInstructionProto ToProto() const override; |
| |
| private: |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| bool IdenticalSlowPath( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| std::string OperandsToStringWithCanonicalNameMap( |
| const HloPrintOptions& options, |
| CanonicalNameMap* canonical_name_map) const override; |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| |
| int64_t parameter_number_ = 0; |
| |
| // Specifies whether each buffer has the same parameter value on all replicas |
| // in data parallelism. |
| std::optional<std::vector<bool>> parameter_replicated_at_leaf_buffers_; |
| }; |
| |
| class HloGetTupleElementInstruction : public HloInstruction { |
| public: |
| explicit HloGetTupleElementInstruction(const Shape& shape, |
| HloInstruction* operand, |
| int64_t index); |
| // Returns the tuple index associated with this instruction. |
| int64_t tuple_index() const { return tuple_index_; } |
| // Sets the tuple index associated with this instruction. |
| void set_tuple_index(int64_t new_tuple_index) { |
| tuple_index_ = new_tuple_index; |
| } |
| // Returns a serialized representation of this instruction. |
| HloInstructionProto ToProto() const override; |
| |
| private: |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| bool IdenticalSlowPath( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| |
| int64_t tuple_index_ = -1; |
| }; |
| |
| class HloReducePrecisionInstruction : public HloInstruction { |
| public: |
| explicit HloReducePrecisionInstruction(const Shape& shape, |
| HloInstruction* operand, |
| const int exponent_bits, |
| const int mantissa_bits); |
| // Returns the number of exponent bits for a reduce-precision node. |
| int32_t exponent_bits() const { return exponent_bits_; } |
| // Returns the number of mantissa bits for a reduce-precision node. |
| int32_t mantissa_bits() const { return mantissa_bits_; } |
| // Returns a serialized representation of this instruction. |
| HloInstructionProto ToProto() const override; |
| |
| private: |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| bool IdenticalSlowPath( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| |
| // The bit sizes for a reduce-precision operation. |
| int32_t exponent_bits_ = 0; |
| int32_t mantissa_bits_ = 0; |
| }; |
| |
| class HloInfeedInstruction : public HloInstruction { |
| public: |
| explicit HloInfeedInstruction(const Shape& infeed_shape, |
| HloInstruction* token_operand, |
| const std::string& config); |
| // Returns the infeed configuration string. The infeed configuration includes |
| // any metadata needed for the backend compiler (e.g., infeed buffer address) |
| // and is target-dependent. |
| std::string infeed_config() const { return infeed_config_; } |
| void set_infeed_config(const std::string& config) { infeed_config_ = config; } |
| // Returns the shape of the data received by the infeed. This is not the same |
| // as the shape of the infeed instruction which produces a tuple containing |
| // the infeed data shape and a TOKEN. |
| const Shape& infeed_shape() const { |
| TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape())); |
| return ShapeUtil::GetSubshape(shape(), {0}); |
| } |
| // Returns a serialized representation of this instruction. |
| HloInstructionProto ToProto() const override; |
| |
| private: |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| bool IdenticalSlowPath( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| |
| // The string representation of the infeed configuration. |
| std::string infeed_config_; |
| }; |
| |
| class HloOutfeedInstruction : public HloInstruction { |
| public: |
| explicit HloOutfeedInstruction(const Shape& outfeed_shape, |
| HloInstruction* operand, |
| HloInstruction* token_operand, |
| absl::string_view outfeed_config); |
| // Returns the shape for the Outfeed instruction. |
| const Shape& outfeed_shape() const { return outfeed_shape_; } |
| // Returns the mutable shape for the Outfeed instruction. |
| Shape* mutable_outfeed_shape() { return &outfeed_shape_; } |
| // Returns the config for the Outfeed instruction. |
| const std::string& outfeed_config() const { return outfeed_config_; } |
| void set_outfeed_config(const std::string& config) { |
| outfeed_config_ = config; |
| } |
| // Returns a serialized representation of this instruction. |
| HloInstructionProto ToProto() const override; |
| |
| private: |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| bool IdenticalSlowPath( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| |
| // Shape of outfeed request. |
| Shape outfeed_shape_; |
| // Outfeed configuration information, only present for kOutfeed. |
| std::string outfeed_config_; |
| }; |
| |
| class HloConvolutionInstruction : public HloInstruction { |
| public: |
| explicit HloConvolutionInstruction( |
| const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, |
| int64_t feature_group_count, int64_t batch_group_count, |
| const Window& window, |
| const ConvolutionDimensionNumbers& dimension_numbers, |
| const PrecisionConfig& precision_config); |
| const Window& window() const override { return window_; } |
| void set_window(const Window& window) override { window_ = window; } |
| const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { |
| return convolution_dimension_numbers_; |
| } |
| void set_convolution_dimension_numbers( |
| const ConvolutionDimensionNumbers& dnums) { |
| convolution_dimension_numbers_ = dnums; |
| } |
| // The number of feature groups. Must be a divisor of the input feature |
| // dimension and output feature dimension. |
| int64_t feature_group_count() const { return feature_group_count_; } |
| void set_feature_group_count(int64_t num_feature_groups) { |
| feature_group_count_ = num_feature_groups; |
| } |
| // The number of batch groups. Must be a divisor of the input batch dimension. |
| int64_t batch_group_count() const { return batch_group_count_; } |
| void set_batch_group_count(int64_t num_batch_groups) { |
| batch_group_count_ = num_batch_groups; |
| } |
| |
| // Returns the information used to tell the implementation information about |
| // what sort of precision is requested. The meaning of the field is backend |
| // specific. At the moment, it is only supported for kConvolution and kDot. |
| // Transformations on one kDot or kConvolution to another will preserve this |
| // information. Transformations to other HLOs will not preserve this |
| // information but it is presumed that the alternate lowering is strictly |
| // superior. |
| const PrecisionConfig& precision_config() const { return precision_config_; } |
| PrecisionConfig* mutable_precision_config() { return &precision_config_; } |
| |
| std::string ToCategory() const override; |
| // Returns a serialized representation of this instruction. |
| HloInstructionProto ToProto() const override; |
| |
| private: |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| bool IdenticalSlowPath( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| // The number of feature groups. Must be a divisor of the input feature |
| // dimension and output feature dimension. |
| int64_t feature_group_count_; |
| // The number of batch groups. Must be a divisor of the input batch dimension. |
| int64_t batch_group_count_; |
| // Describes the window used for a convolution. |
| Window window_; |
| // Describes the dimension numbers used for a convolution. |
| ConvolutionDimensionNumbers convolution_dimension_numbers_; |
| // Information used to communicate to the implementation about the algorithm |
| // used to produce results. See the documentation on precision_config(). |
| PrecisionConfig precision_config_; |
| }; |
| |
| class HloReduceWindowInstruction : public HloInstruction { |
| public: |
| explicit HloReduceWindowInstruction(const Shape& shape, |
| HloInstruction* operand, |
| HloInstruction* init_value, |
| const Window& window, |
| HloComputation* reduce_computation); |
| explicit HloReduceWindowInstruction( |
| const Shape& shape, absl::Span<HloInstruction* const> operands, |
| absl::Span<HloInstruction* const> init_values, const Window& window, |
| HloComputation* reduce_computation); |
| const Window& window() const override { return window_; } |
| void set_window(const Window& window) override { window_ = window; } |
| // Returns a serialized representation of this instruction. |
| HloInstructionProto ToProto() const override; |
| // Returns the number of input arrays (and, consequentially, the number of |
| // init values) this reduce has. |
| int64_t input_count() const { return operand_count() / 2; } |
| // Returns the input tensors to be reduced. |
| absl::Span<HloInstruction* const> inputs() const { |
| return absl::MakeSpan(operands()).subspan(0, input_count()); |
| } |
| // Returns the init values of the reduction. |
| absl::Span<HloInstruction* const> init_values() const { |
| return absl::MakeSpan(operands()).subspan(input_count(), operand_count()); |
| } |
| // Returns the shapes of input tensors to be reduced. |
| absl::InlinedVector<const Shape*, 2> input_shapes() const { |
| absl::InlinedVector<const Shape*, 2> shapes; |
| for (const auto* op : inputs()) { |
| VLOG(2) << "Pushing input array shape for: " << op->ToString() << "\n"; |
| shapes.push_back(&op->shape()); |
| VLOG(2) << "Pushed shape: " << shapes.back()->ToString() << "\n"; |
| } |
| return shapes; |
| } |
| // Returns the init values of the reduction. |
| absl::InlinedVector<const Shape*, 2> init_value_shapes() const { |
| absl::InlinedVector<const Shape*, 2> shapes; |
| for (const auto* op : init_values()) { |
| shapes.push_back(&op->shape()); |
| } |
| return shapes; |
| } |
| // Returns the shapes of the reduced output tensors. |
| absl::InlinedVector<const Shape*, 2> output_shapes() const { |
| absl::InlinedVector<const Shape*, 2> shapes; |
| if (shape().IsArray()) { |
| shapes.push_back(&shape()); |
| } else { |
| for (const Shape& tuple_element_shape : shape().tuple_shapes()) { |
| shapes.push_back(&tuple_element_shape); |
| } |
| } |
| return shapes; |
| } |
| |
| private: |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| bool IdenticalSlowPath( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| |
| Window window_; |
| }; |
| |
| class HloSelectAndScatterInstruction : public HloInstruction { |
| public: |
| explicit HloSelectAndScatterInstruction( |
| const Shape& shape, HloInstruction* operand, HloComputation* select, |
| const Window& window, HloInstruction* source, HloInstruction* init_value, |
| HloComputation* scatter); |
| const Window& window() const override { return window_; } |
| void set_window(const Window& window) override { window_ = window; } |
| // Gets/sets the select or scatter HloComputation for SelectAndScatter. The |
| // setters should only be called by HloModule or HloComputation methods. |
| HloComputation* select() const { |
| return called_computations()[kSelectComputationIndex]; |
| } |
| |
| HloComputation* scatter() const { |
| return called_computations()[kScatterComputationIndex]; |
| } |
| |
| void set_select(HloComputation* computation) { |
| // Don't allow changing the computation for fused instructions so we don't |
| // have to recompute called_instructions for the entire fusion instruction. |
| CHECK(!IsFused()); |
| set_called_computation(kSelectComputationIndex, computation); |
| } |
| |
| void set_scatter(HloComputation* computation) { |
| // Don't allow changing the computation for fused instructions so we don't |
| // have to recompute called_instructions for the entire fusion instruction. |
| CHECK(!IsFused()); |
| set_called_computation(kScatterComputationIndex, computation); |
| } |
| // Returns a serialized representation of this instruction. |
| HloInstructionProto ToProto() const override; |
| |
| private: |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| bool IdenticalSlowPath( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| Window window_; |
| }; |
| |
| class HloCustomCallInstruction : public HloInstruction { |
| public: |
| HloCustomCallInstruction(const Shape& shape, |
| absl::Span<HloInstruction* const> operands, |
| absl::string_view custom_call_target, |
| std::string opaque, |
| CustomCallApiVersion api_version); |
| |
| // Constructor for a custom call with constrained layout. 'shape' and |
| // 'operands_with_layout' must all have layouts. |
| HloCustomCallInstruction(const Shape& shape, |
| absl::Span<HloInstruction* const> operands, |
| absl::string_view custom_call_target, |
| std::string opaque, |
| absl::Span<const Shape> operand_shapes_with_layout, |
| CustomCallApiVersion api_version); |
| |
| // Constructor for a custom call with a to_apply computation. |
| HloCustomCallInstruction(const Shape& shape, |
| absl::Span<HloInstruction* const> operands, |
| HloComputation* to_apply, |
| absl::string_view custom_call_target, |
| std::string opaque, |
| CustomCallApiVersion api_version); |
| |
| // Constructor for a custom call with multiple computations. |
| HloCustomCallInstruction( |
| const Shape& shape, absl::Span<HloInstruction* const> operands, |
| absl::Span<HloComputation* const> called_computations, |
| absl::string_view custom_call_target, std::string opaque, |
| CustomCallApiVersion api_version); |
| |
| const Window& window() const override { |
| CHECK(window_ != nullptr); |
| return *window_; |
| } |
| |
| void set_window(const Window& window) override { |
| window_ = std::make_unique<Window>(window); |
| } |
| |
| const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { |
| CHECK(convolution_dimension_numbers_ != nullptr); |
| return *convolution_dimension_numbers_; |
| } |
| |
| void set_convolution_dimension_numbers( |
| const ConvolutionDimensionNumbers& dnums) { |
| convolution_dimension_numbers_ = |
| std::make_unique<ConvolutionDimensionNumbers>(dnums); |
| } |
| // TODO(jpienaar): Remove this accessor in the follow up. |
| const std::string& opaque() const { return raw_backend_config_string(); } |
| const std::string& custom_call_target() const { return custom_call_target_; } |
| void set_custom_call_target(absl::string_view target) { |
| custom_call_target_ = std::string(target); |
| } |
| void set_feature_group_count(int64_t feature_group_count) { |
| feature_group_count_ = feature_group_count; |
| } |
| void set_batch_group_count(int64_t batch_group_count) { |
| batch_group_count_ = batch_group_count; |
| } |
| // Sets whether this custom call has a side-effect - by default a custom call |
| // has no side-effects. |
| void set_custom_call_has_side_effect(bool custom_call_has_side_effect) { |
| custom_call_has_side_effect_ = custom_call_has_side_effect; |
| } |
| int64_t feature_group_count() const { return feature_group_count_; } |
| int64_t batch_group_count() const { return batch_group_count_; } |
| bool custom_call_has_side_effect() const { |
| return custom_call_has_side_effect_; |
| } |
| // Returns padding type used for ops like convolution. |
| PaddingType padding_type() const { return padding_type_; } |
| |
| void set_padding_type(PaddingType padding_type) { |
| padding_type_ = padding_type; |
| } |
| |
| // Returns the literal associated with this instruction. |
| const Literal& literal() const { return *literal_; } |
| // Set the value of literal to a new one. |
| void set_literal(Literal&& literal) { literal_.emplace(std::move(literal)); } |
| // Returns whether there is literal associated with this instruction. |
| bool HasLiteral() const { return literal_.has_value(); } |
| |
| const PrecisionConfig& precision_config() const { return precision_config_; } |
| PrecisionConfig* mutable_precision_config() { return &precision_config_; } |
| |
| // Returns a serialized representation of this instruction. |
| HloInstructionProto ToProto() const override; |
| |
| // Returns whether the result and operand layouts are constrained. |
| bool layout_constrained() const { return layout_constrained_; } |
| |
| // Returns the shapes (with layout) of the operands. CHECKs if this custom |
| // call does not have constrained layouts. |
| const std::vector<Shape>& operand_shapes_with_layout() const { |
| CHECK(layout_constrained()); |
| return operand_shapes_with_layout_; |
| } |
| // Gets a list of output/operand buffer pairs that alias each other, where the |
| // output buffer is represented as a ShapeIndex, and the operand buffer is |
| // represented as the operand index and the ShapeIndex. By default this list |
| // is empty. |
| const std::vector<std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>& |
| output_to_operand_aliasing() const { |
| return output_to_operand_aliasing_; |
| } |
| // Sets the list of output/operand buffer pairs that alias each other. |
| void set_output_to_operand_aliasing( |
| std::vector<std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>> |
| aliasing) { |
| output_to_operand_aliasing_ = std::move(aliasing); |
| } |
| void set_custom_call_schedule(CustomCallSchedule custom_call_schedule) { |
| custom_call_schedule_ = custom_call_schedule; |
| } |
| CustomCallSchedule custom_call_schedule() const { |
| return custom_call_schedule_; |
| } |
| void set_api_version(CustomCallApiVersion api_version) { |
| api_version_ = api_version; |
| } |
| CustomCallApiVersion api_version() const { return api_version_; } |
| |
| private: |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| bool IdenticalSlowPath( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| // Name of a global symbol to call. |
| std::string custom_call_target_; |
| // Describes the window in a windowed operation such as convolution. |
| std::unique_ptr<Window> window_; |
| // Describes the dimension numbers used for a convolution. |
| std::unique_ptr<ConvolutionDimensionNumbers> convolution_dimension_numbers_; |
| // The number of feature groups. This is used for grouped convolutions. |
| int64_t feature_group_count_; |
| int64_t batch_group_count_; |
| // Whether the result and operand layouts are constrained. |
| bool layout_constrained_; |
| // Information used to communicate to the implementation about the algorithm |
| // used to produce results for convolution instructions. |
| PrecisionConfig precision_config_; |
| // Describes the padding type for convolution instructions. |
| PaddingType padding_type_; |
| // For layout-constrained custom calls, this vector holds the shape with |
| // layout for each operand. |
| std::vector<Shape> operand_shapes_with_layout_; |
| // Whether this custom call has a side-effect. |
| bool custom_call_has_side_effect_; |
| // A list of output/operand buffer pairs that alias each other. See comment of |
| // output_to_operand_aliasing(). |
| std::vector<std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>> |
| output_to_operand_aliasing_; |
| std::optional<Literal> literal_; |
| // A custom-call schedule hint. |
| CustomCallSchedule custom_call_schedule_; |
| // The version of the API used by the custom call function. |
| // TODO(b/189822916): Remove this field when all clients are migrated to the |
| // status-returning API. |
| CustomCallApiVersion api_version_; |
| }; |
| |
| class HloPadInstruction : public HloInstruction { |
| public: |
| explicit HloPadInstruction(const Shape& shape, HloInstruction* operand, |
| HloInstruction* padding_value, |
| const PaddingConfig& padding_config); |
| // Returns the padding configuration for a pad node. |
| const PaddingConfig& padding_config() const { return padding_config_; } |
| PaddingConfig* mutable_padding_config() { return &padding_config_; } |
| // Returns the operand being padded. |
| const HloInstruction* padded_operand() const { return operand(0); } |
| HloInstruction* mutable_padded_operand() { return mutable_operand(0); } |
| // Returns the padding value. |
| const HloInstruction* padding_value() const { return operand(1); } |
| HloInstruction* mutable_padding_value() { return mutable_operand(1); } |
| // Returns a serialized representation of this instruction. |
| HloInstructionProto ToProto() const override; |
| |
| private: |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| bool IdenticalSlowPath( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| |
| // The padding configuration that describes the edge padding and interior |
| // padding of this pad instruction. |
| PaddingConfig padding_config_; |
| }; |
| |
| class HloDynamicIndexInstruction : public HloInstruction { |
| public: |
| explicit HloDynamicIndexInstruction(HloOpcode opcode, const Shape& shape) |
| : HloInstruction(opcode, shape) {} |
| virtual int64_t first_index_operand_number() const = 0; |
| |
| // Returns a subspan of operands which represent the start indices. |
| absl::Span<HloInstruction* const> index_operands() const { |
| return absl::MakeSpan(operands()).subspan(first_index_operand_number()); |
| } |
| |
| // Returns the shapes of the index operands. |
| std::vector<Shape> index_shapes() const { |
| std::vector<Shape> shapes; |
| auto indices = index_operands(); |
| for (const HloInstruction* index : indices) { |
| shapes.push_back(index->shape()); |
| } |
| return shapes; |
| } |
| }; |
| |
| class HloDynamicSliceInstruction : public HloDynamicIndexInstruction { |
| public: |
| explicit HloDynamicSliceInstruction(const Shape& shape, |
| HloInstruction* operand, |
| HloInstruction* start_indices, |
| absl::Span<const int64_t> slice_sizes); |
| explicit HloDynamicSliceInstruction( |
| const Shape& shape, HloInstruction* operand, |
| absl::Span<HloInstruction* const> start_indices, |
| absl::Span<const int64_t> slice_sizes); |
| // Old methods kept for smooth subclassing transition END. |
| // Returns the size of the slice in the given dimension for a dynamic |
| // slice node. |
| int64_t slice_sizes(int64_t dimension) const { |
| return dynamic_slice_sizes_[dimension]; |
| } |
| const std::vector<int64_t>& dynamic_slice_sizes() const { |
| return dynamic_slice_sizes_; |
| } |
| // Returns a serialized representation of this instruction. |
| HloInstructionProto ToProto() const override; |
| |
| int64_t first_index_operand_number() const override { return 1; } |
| |
| private: |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| bool IdenticalSlowPath( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| |
| // Describes the [start, start + size) range size for a dynamic slice |
| // ('start' is specified dynamically in the second operand of the operation). |
| std::vector<int64_t> dynamic_slice_sizes_; |
| }; |
| |
| class HloDynamicUpdateSliceInstruction : public HloDynamicIndexInstruction { |
| public: |
| explicit HloDynamicUpdateSliceInstruction(const Shape& shape, |
| HloInstruction* operand, |
| HloInstruction* update, |
| HloInstruction* start_indices); |
| explicit HloDynamicUpdateSliceInstruction( |
| const Shape& shape, HloInstruction* operand, HloInstruction* update, |
| absl::Span<HloInstruction* const> start_indices); |
| |
| int64_t first_index_operand_number() const override { return 2; } |
| }; |
| |
| class HloGatherInstruction : public HloInstruction { |
| public: |
| explicit HloGatherInstruction( |
| const Shape& shape, HloInstruction* operand, |
| HloInstruction* start_indices, |
| const GatherDimensionNumbers& gather_dim_numbers, |
| absl::Span<const int64_t> slice_sizes, bool indices_are_sorted); |
| const GatherDimensionNumbers& gather_dimension_numbers() const { |
| CHECK(gather_dimension_numbers_ != nullptr); |
| return *gather_dimension_numbers_; |
| } |
| absl::Span<const int64_t> gather_slice_sizes() const { |
| return gather_slice_sizes_; |
| } |
| bool indices_are_sorted() const { return indices_are_sorted_; } |
| void set_indices_are_sorted(bool indices_are_sorted) { |
| indices_are_sorted_ = indices_are_sorted; |
| } |
| // Returns a serialized representation of this instruction. |
| HloInstructionProto ToProto() const override; |
| |
| // Creates an instance of GatherDimensionNumbers. |
| static GatherDimensionNumbers MakeGatherDimNumbers( |
| absl::Span<const int64_t> offset_dims, |
| absl::Span<const int64_t> collapsed_slice_dims, |
| absl::Span<const int64_t> start_index_map, int64_t index_vector_dim); |
| // Returns the dump string of the given gather dimension numbers. |
| static std::string GatherDimensionNumbersToString( |
| const GatherDimensionNumbers& gather_dimension_numbers); |
| |
| private: |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| bool IdenticalSlowPath( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| |
| std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_; |
| std::vector<int64_t> gather_slice_sizes_; |
| bool indices_are_sorted_; |
| }; |
| |
| class HloScatterInstruction : public HloInstruction { |
| public: |
| explicit HloScatterInstruction( |
| const Shape& shape, absl::Span<HloInstruction* const> args, |
| HloComputation* update_computation, |
| const ScatterDimensionNumbers& scatter_dim_numbers, |
| bool indices_are_sorted, bool unique_indices); |
| const ScatterDimensionNumbers& scatter_dimension_numbers() const { |
| CHECK(scatter_dimension_numbers_ != nullptr); |
| return *scatter_dimension_numbers_; |
| } |
| bool indices_are_sorted() const { return indices_are_sorted_; } |
| void set_indices_are_sorted(bool indices_are_sorted) { |
| indices_are_sorted_ = indices_are_sorted; |
| } |
| bool unique_indices() const override { return unique_indices_; } |
| // Returns a serialized representation of this instruction. |
| HloInstructionProto ToProto() const override; |
| int64_t scatter_operand_count() const { return operand_count() / 2; } |
| absl::Span<HloInstruction* const> scatter_operands() const { |
| return absl::MakeConstSpan(operands()).first(scatter_operand_count()); |
| } |
| absl::Span<HloInstruction* const> scatter_updates() const { |
| return absl::MakeConstSpan(operands()).last(scatter_operand_count()); |
| } |
| const HloInstruction* scatter_indices() const { |
| return operand(scatter_operand_count()); |
| } |
| HloInstruction* scatter_indices() { |
| return mutable_operand(scatter_operand_count()); |
| } |
| |
| // Creates an instance of ScatterDimensionNumbers. |
| static ScatterDimensionNumbers MakeScatterDimNumbers( |
| absl::Span<const int64_t> update_window_dims, |
| absl::Span<const int64_t> inserted_window_dims, |
| absl::Span<const int64_t> scatter_dims_to_operand_dims, |
| int64_t index_vector_dim); |
| // Returns the dump string of the given scatter dimension numbers. |
| static std::string ScatterDimensionNumbersToString( |
| const ScatterDimensionNumbers& scatter_dimension_numbers); |
| |
| private: |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| bool IdenticalSlowPath( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| |
| std::unique_ptr<ScatterDimensionNumbers> scatter_dimension_numbers_; |
| bool indices_are_sorted_; |
| bool unique_indices_; |
| }; |
| |
| class HloIotaInstruction : public HloInstruction { |
| public: |
| explicit HloIotaInstruction(const Shape& shape, int64_t iota_dimension); |
| |
| // Returns the dimension sizes or numbers associated with this instruction. |
| int64_t iota_dimension() const { return iota_dimension_; } |
| absl::Span<const int64_t> dimensions() const override { |
| return absl::MakeConstSpan(&iota_dimension_, 1); |
| } |
| // Returns a serialized representation of this instruction. |
| HloInstructionProto ToProto() const override; |
| |
| private: |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| bool IdenticalSlowPath( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| |
| int64_t iota_dimension_; |
| }; |
| |
| class HloDotInstruction : public HloInstruction { |
| public: |
| // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch |
| // dimensions specified in 'dimension_numbers'. |
| explicit HloDotInstruction(const Shape& shape, HloInstruction* lhs, |
| HloInstruction* rhs, |
| const DotDimensionNumbers& dimension_numbers, |
| const PrecisionConfig& precision_config); |
| |
| // Returns data on the dimension numbers used for a dot operation. |
| const DotDimensionNumbers& dot_dimension_numbers() const { |
| return dot_dimension_numbers_; |
| } |
| |
| // Returns the information used to tell the implementation information about |
| // what sort of precision is requested. The meaning of the field is backend |
| // specific. At the moment, it is only supported for kConvolution and kDot. |
| // Transformations on one kDot or kConvolution to another will preserve this |
| // information. Transformations to other HLOs will not preserve this |
| // information but it is presumed that the alternate lowering is strictly |
| // superior. |
| const PrecisionConfig& precision_config() const { return precision_config_; } |
| PrecisionConfig* mutable_precision_config() { return &precision_config_; } |
| |
| // Returns a serialized representation of this instruction. |
| HloInstructionProto ToProto() const override; |
| |
| private: |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| bool IdenticalSlowPath( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| |
| // Describes the dimension numbers used for a dot. |
| DotDimensionNumbers dot_dimension_numbers_; |
| |
| // Information used to communicate to the implementation about the algorithm |
| // used to produce results. See the documentation on precision_config(). |
| PrecisionConfig precision_config_; |
| }; |
| |
| class HloDomainInstruction : public HloInstruction { |
| public: |
| explicit HloDomainInstruction( |
| const Shape& shape, HloInstruction* operand, |
| std::unique_ptr<DomainMetadata> operand_side_metadata, |
| std::unique_ptr<DomainMetadata> user_side_metadata); |
| |
| // Returns a serialized representation of this instruction. |
| HloInstructionProto ToProto() const override; |
| |
| // Retrieves the operand side metadata of a kDomain instruction. |
| const DomainMetadata& operand_side_metadata() const { |
| return *operand_side_metadata_; |
| } |
| // Retrieves the user side metadata of a kDomain instruction. |
| const DomainMetadata& user_side_metadata() const { |
| return *user_side_metadata_; |
| } |
| |
| private: |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| bool IdenticalSlowPath( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| |
| std::unique_ptr<DomainMetadata> operand_side_metadata_; |
| std::unique_ptr<DomainMetadata> user_side_metadata_; |
| }; |
| |
| class HloGetDimensionSizeInstruction : public HloInstruction { |
| public: |
| explicit HloGetDimensionSizeInstruction(const Shape& shape, |
| HloInstruction* operand, |
| int64_t dimension); |
| |
| // Returns the dimension sizes or numbers associated with this instruction. |
| int64_t dimension() const { return dimension_; } |
| // Returns a serialized representation of this instruction. |
| HloInstructionProto ToProto() const override; |
| |
| private: |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| bool IdenticalSlowPath( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| |
| int64_t dimension_; |
| }; |
| |
| class HloSetDimensionSizeInstruction : public HloInstruction { |
| public: |
| explicit HloSetDimensionSizeInstruction(const Shape& shape, |
| HloInstruction* operand, |
| HloInstruction* val, |
| int64_t dimension); |
| |
| // Returns the dimension sizes or numbers associated with this instruction. |
| int64_t dimension() const { return dimension_; } |
| // Returns a serialized representation of this instruction. |
| HloInstructionProto ToProto() const override; |
| |
| private: |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| bool IdenticalSlowPath( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| |
| int64_t dimension_; |
| }; |
| |
| class HloRngGetAndUpdateStateInstruction : public HloInstruction { |
| public: |
| explicit HloRngGetAndUpdateStateInstruction(const Shape& shape, |
| int64_t delta); |
| |
| // Returns the delta value. |
| int64_t delta() const { return delta_; } |
| void set_delta(int64_t delta) { delta_ = delta; } |
| // Returns a serialized representation of this instruction. |
| HloInstructionProto ToProto() const override; |
| |
| private: |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| bool IdenticalSlowPath( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| // Implementation for non-common logic of CloneWithNewOperands. |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| |
| int64_t delta_; |
| }; |
| |
| class HloRngBitGeneratorInstruction : public HloInstruction { |
| public: |
| HloRngBitGeneratorInstruction(const Shape& shape, HloInstruction* state, |
| RandomAlgorithm algorithm); |
| |
| RandomAlgorithm algorithm() const { return algorithm_; } |
| HloInstructionProto ToProto() const override; |
| |
| private: |
| std::vector<std::string> ExtraAttributesToStringImpl( |
| const HloPrintOptions& options) const override; |
| bool IdenticalSlowPath( |
| const HloInstruction& other, |
| const std::function<bool(const HloComputation*, const HloComputation*)>& |
| eq_computations) const override; |
| std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( |
| const Shape& shape, absl::Span<HloInstruction* const> new_operands, |
| HloCloneContext* context) const override; |
| |
| RandomAlgorithm algorithm_; |
| }; |
| |
| } // namespace xla |
| |
| #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ |