// blob: 88e874347bd220c677779ab44708bfd93e7d4aa9 [file] [log] [blame]
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// All HloInstruction subclasses are put in this file.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
// Shared base class of the batch-normalization instruction classes declared
// below in this file. It holds the two attributes they have in common:
// `epsilon` and `feature_index`.
class HloBatchNormInstruction : public HloInstruction {
 public:
  // Returns the epsilon value associated with the instruction. This is a small
  // number added to the variance to avoid a divide-by-zero error.
  float epsilon() const { return epsilon_; }

  // Returns the feature_index field associated with the instruction, i.e. the
  // index of the feature dimension.
  int64 feature_index() const { return feature_index_; }

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

 protected:
  // Reachable from derived classes only; the derived constructors do not take
  // an opcode, so they are the ones that choose `opcode`.
  explicit HloBatchNormInstruction(HloOpcode opcode, const Shape& shape,
                                   HloInstruction* operand,
                                   HloInstruction* scale, float epsilon,
                                   int64 feature_index);

 private:
  // Equality hook for the attributes owned by this class.
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // Printing hook for the attributes owned by this class.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;

  // A small float number added to the variance to avoid divide-by-zero error.
  float epsilon_ = 0.0f;

  // An integer value representing the index of the feature dimension.
  int64 feature_index_ = -1;
};
// Batch-norm instruction constructed from an operand together with `scale`
// and `offset` inputs. The epsilon() and feature_index() accessors are
// inherited from HloBatchNormInstruction.
class HloBatchNormTrainingInstruction : public HloBatchNormInstruction {
 public:
  explicit HloBatchNormTrainingInstruction(const Shape& shape,
                                           HloInstruction* operand,
                                           HloInstruction* scale,
                                           HloInstruction* offset,
                                           float epsilon, int64 feature_index);

 private:
  // Subclass-specific part of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;
};
// Batch-norm instruction constructed from an operand together with `scale`,
// `offset`, `mean` and `variance` inputs. The epsilon() and feature_index()
// accessors are inherited from HloBatchNormInstruction.
class HloBatchNormInferenceInstruction : public HloBatchNormInstruction {
 public:
  explicit HloBatchNormInferenceInstruction(
      const Shape& shape, HloInstruction* operand, HloInstruction* scale,
      HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
      float epsilon, int64 feature_index);

 private:
  // Subclass-specific part of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;
};
// Batch-norm instruction constructed from an operand together with `scale`,
// `mean`, `variance` and `grad_output` inputs. The epsilon() and
// feature_index() accessors are inherited from HloBatchNormInstruction.
class HloBatchNormGradInstruction : public HloBatchNormInstruction {
 public:
  explicit HloBatchNormGradInstruction(
      const Shape& shape, HloInstruction* operand, HloInstruction* scale,
      HloInstruction* mean, HloInstruction* variance,
      HloInstruction* grad_output, float epsilon, int64 feature_index);

 private:
  // Subclass-specific part of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;
};
// FFT instruction. Carries the FFT type and the FFT length in addition to its
// single operand.
class HloFftInstruction : public HloInstruction {
 public:
  explicit HloFftInstruction(const Shape& shape, HloInstruction* operand,
                             FftType fft_type,
                             absl::Span<const int64> fft_length);

  // The FFT length stored for this instruction.
  const std::vector<int64>& fft_length() const { return fft_length_; }

  // The FFT type stored for this instruction.
  FftType fft_type() const { return fft_type_; }

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

 private:
  // Subclass-specific part of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Equality hook for the attributes owned by this class.
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // Printing hook for the attributes owned by this class.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;

  // Describes FFT type for an FFT instruction.
  FftType fft_type_ = FftType::FFT;

  // Indicates the FFT length for an FFT instruction.
  std::vector<int64> fft_length_;
};
// Copy-start instruction. Besides its operand it records whether the copy is
// a cross-program prefetch.
class HloCopyStartInstruction : public HloInstruction {
 public:
  explicit HloCopyStartInstruction(const Shape& shape, HloInstruction* operand,
                                   bool is_cross_program_prefetch);

  // Returns the cross-program-prefetch flag stored for this instruction.
  bool is_cross_program_prefetch() const { return is_cross_program_prefetch_; }

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

 private:
  // Printing hook for the attributes owned by this class.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;

  // Equality hook for the attributes owned by this class.
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // Subclass-specific part of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // In-class initializer, matching the other scalar members in this header
  // (e.g. epsilon_, feature_index_, fft_type_), so the flag never holds an
  // indeterminate value. The constructor (defined out of line) is still
  // expected to set the real value.
  bool is_cross_program_prefetch_ = false;
};
// Comparison instruction over two operands. The direction and the comparison
// type are both stored in a single `Comparison` object.
class HloCompareInstruction : public HloInstruction {
 public:
  explicit HloCompareInstruction(const Shape& shape, HloInstruction* lhs,
                                 HloInstruction* rhs,
                                 ComparisonDirection direction,
                                 absl::optional<Comparison::Type> type);

  // The comparison type, as reported by the stored Comparison.
  Comparison::Type type() const { return compare_.GetType(); }

  // The comparison direction, as reported by the stored Comparison.
  ComparisonDirection direction() const { return compare_.GetDirection(); }

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

 private:
  // Subclass-specific part of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Equality hook for the attributes owned by this class.
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // Printing hook for the attributes owned by this class.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;

  // Holds both the direction and the type of the comparison.
  Comparison compare_;
};
// Triangular-solve instruction over operands `a` and `b`, parameterized by a
// TriangularSolveOptions value.
class HloTriangularSolveInstruction : public HloInstruction {
 public:
  explicit HloTriangularSolveInstruction(const Shape& shape, HloInstruction* a,
                                         HloInstruction* b,
                                         const TriangularSolveOptions& options);

  // The options this instruction was created with.
  const TriangularSolveOptions& triangular_solve_options() const {
    return triangular_solve_options_;
  }

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

 private:
  // Subclass-specific part of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Equality hook for the attributes owned by this class.
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // Printing hook for the attributes owned by this class.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;

  TriangularSolveOptions triangular_solve_options_;
};
// Cholesky instruction over operand `a`, parameterized by a CholeskyOptions
// value.
class HloCholeskyInstruction : public HloInstruction {
 public:
  explicit HloCholeskyInstruction(const Shape& shape, HloInstruction* a,
                                  const CholeskyOptions& options);

  // The options this instruction was created with.
  const CholeskyOptions& cholesky_options() const { return cholesky_options_; }

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

 private:
  // Subclass-specific part of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Equality hook for the attributes owned by this class.
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // Printing hook for the attributes owned by this class.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;

  CholeskyOptions cholesky_options_;
};
// Class that represents instructions that synchronize and transfer data between
// partitioned devices. Send/Recv and collective instructions (AllReduce,
// AllToAll, CollectivePermute) belong to this instruction type. A group of
// instructions (of the same opcode) with the same channel_id communicate during
// execution.
class HloChannelInstruction : public HloInstruction {
 public:
  // Returns the channel id associated with the instruction. The id is
  // shared between each Send/Recv pair or a group of collective instructions
  // and is globally unique to identify each channel.
  absl::optional<int64> channel_id() const { return channel_id_; }
  void set_channel_id(const absl::optional<int64>& channel_id);

  // Whether this instruction is identical to `other` except for the values of
  // channel IDs, as long as both have channel IDs or neither has a channel ID.
  //
  // This base implementation only checks that the two instructions agree on
  // whether a channel id is present; `eq_computations` is not consulted here.
  virtual bool IdenticalSlowPathIgnoringChannelIdValues(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const {
    const bool this_has_channel = channel_id_.has_value();
    const bool other_has_channel = other.channel_id().has_value();
    return this_has_channel == other_has_channel;
  }

 protected:
  explicit HloChannelInstruction(HloOpcode opcode, const Shape& shape,
                                 const absl::optional<int64>& channel_id);

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // Printing hook for the attributes owned by this class.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;

  // Do not override IdenticalSlowPath(). Override
  // IdenticalSlowPathIgnoringChannelIdValues() instead.
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const final;

  // Empty when the instruction has no channel id.
  absl::optional<int64> channel_id_;
};
// Common base of the Send, SendDone, Recv and RecvDone instruction classes
// declared below. Adds the `is_host_transfer` flag on top of the channel id
// handled by HloChannelInstruction.
class HloSendRecvInstruction : public HloChannelInstruction {
 public:
  // Returns whether this send/recv instruction sends data to/from the host.
  bool is_host_transfer() const { return is_host_transfer_; }

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

 protected:
  // Reachable from derived classes only; the derived constructors do not take
  // an opcode, so they are the ones that choose `opcode`.
  explicit HloSendRecvInstruction(HloOpcode opcode, const Shape& shape,
                                  int64 channel_id, bool is_host_transfer);

 private:
  // Printing hook for the attributes owned by this class.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;

  // Equality hook; overridden here because HloChannelInstruction declares
  // IdenticalSlowPath() final.
  bool IdenticalSlowPathIgnoringChannelIdValues(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // Whether this send/recv instruction sends data to/from the host.
  // In-class initializer, matching the other scalar members in this header
  // (e.g. epsilon_, feature_index_), so the flag never holds an indeterminate
  // value. The constructor (defined out of line) is still expected to set the
  // real value.
  bool is_host_transfer_ = false;
};
// Send instruction, constructed from the data operand, a token, the channel
// id and the host-transfer flag.
class HloSendInstruction : public HloSendRecvInstruction {
 public:
  explicit HloSendInstruction(HloInstruction* operand, HloInstruction* token,
                              int64 channel_id, bool is_host_transfer);

 private:
  // Subclass-specific part of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;
};
// Send-done instruction, constructed from the HloSendInstruction it operates
// on and the host-transfer flag.
class HloSendDoneInstruction : public HloSendRecvInstruction {
 public:
  explicit HloSendDoneInstruction(HloSendInstruction* operand,
                                  bool is_host_transfer);

 private:
  // Subclass-specific part of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;
};
// Recv instruction, constructed from a shape, a token, the channel id and the
// host-transfer flag.
class HloRecvInstruction : public HloSendRecvInstruction {
 public:
  explicit HloRecvInstruction(const Shape& shape, HloInstruction* token,
                              int64 channel_id, bool is_host_transfer);

 private:
  // Subclass-specific part of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;
};
// Recv-done instruction, constructed from the HloRecvInstruction it operates
// on and the host-transfer flag.
class HloRecvDoneInstruction : public HloSendRecvInstruction {
 public:
  explicit HloRecvDoneInstruction(HloRecvInstruction* operand,
                                  bool is_host_transfer);

 private:
  // Subclass-specific part of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;
};
// Common base of the collective instruction classes declared below in this
// file (all-gather, all-reduce, all-to-all). Adds replica groups and the
// `constrain_layout` flag on top of the channel id handled by
// HloChannelInstruction.
class HloCollectiveInstruction : public HloChannelInstruction {
 public:
  // Returns the replica groups this instruction was created with.
  const std::vector<ReplicaGroup>& replica_groups() const {
    return replica_groups_;
  }

  // Returns true if the layout of the AllReduce is enforced by XLA client (as
  // the layout set in the shape). The only reason for the client to set the
  // layout is to separately compile computations that communicate with
  // AllReduce. Since this field is only set `true` by the client, the compiler
  // only needs to propagate existing values (e.g., Clone, X64Rewriter) or set
  // `false` for all other cases.
  //
  // When this is `true`, there may be communication endpoints outside the
  // current compilation unit, so the compiler considers this AllReduce as
  // side-effecting to disable compiler transformations. The compiler is free to
  // transform unconstrained AllReduces differently across compilation units.
  // It is an error for an HloModule to have a mix of constrained and
  // unconstrained AllReduce instructions (checked by HloVerifier).
  bool constrain_layout() const { return constrain_layout_; }

 protected:
  // Reachable from derived classes only.
  explicit HloCollectiveInstruction(
      HloOpcode opcode, const Shape& shape,
      absl::Span<HloInstruction* const> operands,
      const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
      const absl::optional<int64>& channel_id);

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // Printing hook for the attributes owned by this class.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;

  // Equality hook; overridden here because HloChannelInstruction declares
  // IdenticalSlowPath() final.
  bool IdenticalSlowPathIgnoringChannelIdValues(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  std::vector<ReplicaGroup> replica_groups_;

  // In-class initializer, matching the other scalar members in this header
  // (e.g. epsilon_, feature_index_), so the flag never holds an indeterminate
  // value. The constructor (defined out of line) is still expected to set the
  // real value.
  bool constrain_layout_ = false;
};
// All-gather instruction. On top of the HloCollectiveInstruction attributes it
// stores the gather dimension and the `use_global_device_ids` flag.
class HloAllGatherInstruction : public HloCollectiveInstruction {
 public:
  explicit HloAllGatherInstruction(
      const Shape& shape, HloInstruction* operand, int64 all_gather_dimension,
      const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
      const absl::optional<int64>& channel_id, bool use_global_device_ids);

  // Same as HloAllReduceInstruction::use_global_device_ids.
  bool use_global_device_ids() const { return use_global_device_ids_; }

  // The dimension on which data from different participants are concatenated.
  int64 all_gather_dimension() const { return all_gather_dimension_; }

 protected:
  // Printing hook for the attributes owned by this class.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

 private:
  // Equality hook; overridden here because HloChannelInstruction declares
  // IdenticalSlowPath() final.
  bool IdenticalSlowPathIgnoringChannelIdValues(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // In-class initializers, matching the other scalar members in this header,
  // so neither member ever holds an indeterminate value. -1 mirrors the
  // default used for feature_index_ in HloBatchNormInstruction. The
  // constructor (defined out of line) is still expected to set the real
  // values.
  int64 all_gather_dimension_ = -1;
  bool use_global_device_ids_ = false;
};
// All-reduce instruction. On top of the HloCollectiveInstruction attributes it
// takes a reduction computation and stores the `use_global_device_ids` flag.
class HloAllReduceInstruction : public HloCollectiveInstruction {
 public:
  explicit HloAllReduceInstruction(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      HloComputation* reduce_computation,
      const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
      const absl::optional<int64>& channel_id, bool use_global_device_ids);

  // Returns true if the AllReduce does no communication, so it's equivalent
  // to a mem copy.
  bool IsNoop() const;

  // Returns true if the ids in the ReplicaGroup config represent a global id of
  // (replica_id * partition_count + partition_id) instead of a replica id.
  // This enables more flexible grouping of devices if this all-reduce is both
  // cross-partition and cross-replica.
  //
  // For example with 2 replicas and 4 partitions,
  // replica_groups={{0,1,4,5},{2,3,6,7}}, use_global_device_ids=true means that
  //   group[0] = (0,0), (0,1), (1,0), (1,1)
  //   group[1] = (0,2), (0,3), (1,2), (1,3)
  // where each pair is (replica_id, partition_id).
  bool use_global_device_ids() const { return use_global_device_ids_; }

 protected:
  // Printing hook for the attributes owned by this class.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

 private:
  // Equality hook; overridden here because HloChannelInstruction declares
  // IdenticalSlowPath() final.
  bool IdenticalSlowPathIgnoringChannelIdValues(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // In-class initializer, matching the other scalar members in this header
  // (e.g. epsilon_, feature_index_), so the flag never holds an indeterminate
  // value. The constructor (defined out of line) is still expected to set the
  // real value.
  bool use_global_device_ids_ = false;
};
// All-to-all instruction. On top of the HloCollectiveInstruction attributes it
// stores an optional split dimension.
class HloAllToAllInstruction : public HloCollectiveInstruction {
 public:
  explicit HloAllToAllInstruction(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
      const absl::optional<int64>& channel_id,
      const absl::optional<int64>& split_dimension);

  // AllToAll can optionally take a split dimension, which means that this
  // AllToAll takes a single (flattened) array operand and produces an array
  // output (instead of taking a list of operands and producing a tuple).
  //
  // split_dimension specifies which dimension in the operand is split across
  // devices in each replica_group, and also means the concatenated dimension
  // on the output (i.e., input and the output shapes are the same).
  absl::optional<int64> split_dimension() const { return split_dimension_; }

  // Overwrites the stored split dimension with `dim`.
  void set_split_dimension(int64 dim) { split_dimension_ = dim; }

 protected:
  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // Printing hook for the attributes owned by this class.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;

 private:
  // Subclass-specific part of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Equality hook; overridden here because HloChannelInstruction declares
  // IdenticalSlowPath() final.
  bool IdenticalSlowPathIgnoringChannelIdValues(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // Empty when no split dimension was given.
  absl::optional<int64> split_dimension_;
};
// Collective-permute instruction. Stores the list of (source, target) pairs
// on top of the channel id handled by HloChannelInstruction.
class HloCollectivePermuteInstruction : public HloChannelInstruction {
 public:
  explicit HloCollectivePermuteInstruction(
      HloOpcode opcode, const Shape& shape, HloInstruction* operand,
      const std::vector<std::pair<int64, int64>>& source_target_pairs,
      const absl::optional<int64>& channel_id);

  // The (source, target) pairs this instruction was created with.
  const std::vector<std::pair<int64, int64>>& source_target_pairs() const {
    return source_target_pairs_;
  }

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

 private:
  // Subclass-specific part of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Equality hook; overridden here because HloChannelInstruction declares
  // IdenticalSlowPath() final.
  bool IdenticalSlowPathIgnoringChannelIdValues(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // Printing hook for the attributes owned by this class.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;

  // Declared const: cannot be modified after construction.
  const std::vector<std::pair<int64, int64>> source_target_pairs_;
};
// Reverse instruction. Stores the list of dimensions passed at construction.
class HloReverseInstruction : public HloInstruction {
 public:
  explicit HloReverseInstruction(const Shape& shape, HloInstruction* operand,
                                 absl::Span<const int64> dimensions);

  // Returns the dimension sizes or numbers associated with this instruction.
  const std::vector<int64>& dimensions() const override { return dimensions_; }
  // Returns the `index`-th entry of dimensions(); no bounds check.
  int64 dimensions(int64 index) const override { return dimensions()[index]; }
  // Gives mutable access to the stored dimension list.
  std::vector<int64>* mutable_dimensions() override { return &dimensions_; }

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

 private:
  // Subclass-specific part of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Equality hook for the attributes owned by this class.
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // Printing hook for the attributes owned by this class.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;

  std::vector<int64> dimensions_;
};
// Concatenate instruction over a list of operands. The concatenation dimension
// is exposed as entry 0 of the dimension list.
class HloConcatenateInstruction : public HloInstruction {
 public:
  explicit HloConcatenateInstruction(const Shape& shape,
                                     absl::Span<HloInstruction* const> operands,
                                     int64 dimension);

  // Returns the dimension sizes or numbers associated with this instruction.
  const std::vector<int64>& dimensions() const override { return dimensions_; }
  // Returns the `index`-th entry of dimensions(); no bounds check.
  int64 dimensions(int64 index) const override { return dimensions()[index]; }
  // Gives mutable access to the stored dimension list.
  std::vector<int64>* mutable_dimensions() override { return &dimensions_; }

  // Accessor for the dimension in which a concatenate HLO should occur.
  int64 concatenate_dimension() const { return dimensions(0); }

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

 private:
  // Subclass-specific part of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Equality hook for the attributes owned by this class.
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // Printing hook for the attributes owned by this class.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;

  std::vector<int64> dimensions_;
};
// Reduce instruction. Its operand list is split in two halves: the first
// input_count() operands are the arrays to reduce and the remaining operands
// are the init values.
class HloReduceInstruction : public HloInstruction {
 public:
  explicit HloReduceInstruction(const Shape& shape,
                                absl::Span<HloInstruction* const> args,
                                absl::Span<const int64> dimensions_to_reduce,
                                HloComputation* reduce_computation);

  // Returns the dimension sizes or numbers associated with this instruction.
  const std::vector<int64>& dimensions() const override { return dimensions_; }
  // Returns the `index`-th entry of dimensions(); no bounds check.
  int64 dimensions(int64 index) const override { return dimensions()[index]; }
  // Gives mutable access to the stored dimension list.
  std::vector<int64>* mutable_dimensions() override { return &dimensions_; }

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // Returns the number of input arrays (and, consequently, the number of
  // init values) this reduce has.
  int64 input_count() const { return operand_count() / 2; }

  // Returns the input tensors to be reduced.
  absl::Span<HloInstruction* const> inputs() const {
    const auto all_operands = absl::MakeSpan(operands());
    return all_operands.subspan(0, input_count());
  }

  // Returns the init values of the reduction: every operand from index
  // input_count() to the end.
  absl::Span<HloInstruction* const> init_values() const {
    const auto all_operands = absl::MakeSpan(operands());
    return all_operands.subspan(input_count());
  }

 private:
  // Subclass-specific part of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Equality hook for the attributes owned by this class.
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // Printing hook for the attributes owned by this class.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;

  std::vector<int64> dimensions_;
};
// Sort instruction. Operand 0 is the keys operand and the remaining operands
// are counted as values; the sort dimension is entry 0 of the dimension list.
class HloSortInstruction : public HloInstruction {
 public:
  explicit HloSortInstruction(const Shape& shape, int64 dimension,
                              absl::Span<HloInstruction* const> operands,
                              HloComputation* compare, bool is_stable);

  // Returns the dimension sizes or numbers associated with this instruction.
  const std::vector<int64>& dimensions() const override { return dimensions_; }
  // Returns the `index`-th entry of dimensions(); no bounds check.
  int64 dimensions(int64 index) const override { return dimensions()[index]; }
  // Gives mutable access to the stored dimension list.
  std::vector<int64>* mutable_dimensions() override { return &dimensions_; }

  // Returns the sort dimension for this instruction.
  int64 sort_dimension() const { return dimensions(0); }

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // Returns the key operand to this instruction.
  const HloInstruction* keys() const { return operand(0); }
  HloInstruction* mutable_keys() { return mutable_operand(0); }

  // Returns the number of value operands.
  int64 values_count() const { return operand_count() - 1; }

  // Returns the stability flag this instruction was created with.
  bool is_stable() const { return is_stable_; }

 private:
  // Printing hook for the attributes owned by this class.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;

  // Equality hook for the attributes owned by this class.
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  std::vector<int64> dimensions_;

  // In-class initializer, matching the other scalar members in this header
  // (e.g. epsilon_, feature_index_), so the flag never holds an indeterminate
  // value. The constructor (defined out of line) is still expected to set the
  // real value.
  bool is_stable_ = false;
};
// Transpose instruction. Stores the list of dimensions passed at construction.
class HloTransposeInstruction : public HloInstruction {
 public:
  explicit HloTransposeInstruction(const Shape& shape, HloInstruction* operand,
                                   absl::Span<const int64> dimensions);

  // Returns the dimension sizes or numbers associated with this instruction.
  const std::vector<int64>& dimensions() const override { return dimensions_; }
  // Returns the `index`-th entry of dimensions(); no bounds check.
  int64 dimensions(int64 index) const override { return dimensions()[index]; }
  // Gives mutable access to the stored dimension list.
  std::vector<int64>* mutable_dimensions() override { return &dimensions_; }

  // Returns whether this instruction does a rank-2 transposition.
  bool IsRank2Transpose() const;

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

 private:
  // Subclass-specific part of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Equality hook for the attributes owned by this class.
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // Printing hook for the attributes owned by this class.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;

  std::vector<int64> dimensions_;
};
// Broadcast instruction. Stores the list of broadcast dimensions passed at
// construction.
class HloBroadcastInstruction : public HloInstruction {
 public:
  explicit HloBroadcastInstruction(const Shape& shape, HloInstruction* operand,
                                   absl::Span<const int64> broadcast_dimension);

  // Returns the dimension sizes or numbers associated with this instruction.
  const std::vector<int64>& dimensions() const override { return dimensions_; }
  // Returns the `index`-th entry of dimensions(); no bounds check.
  int64 dimensions(int64 index) const override { return dimensions()[index]; }
  // Gives mutable access to the stored dimension list.
  std::vector<int64>* mutable_dimensions() override { return &dimensions_; }

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

 private:
  // Subclass-specific part of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Equality hook for the attributes owned by this class.
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // Printing hook for the attributes owned by this class.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;

  std::vector<int64> dimensions_;
};
// Dynamic-reshape instruction. Operand 0 is the data operand; every operand
// after it is a dimension-size operand.
class HloDynamicReshapeInstruction : public HloInstruction {
 public:
  explicit HloDynamicReshapeInstruction(
      const Shape& shape, HloInstruction* data_operand,
      absl::Span<HloInstruction* const> dim_sizes);

  // Returns the input dim sizes dimensions, which is operands[1:].
  absl::Span<HloInstruction* const> dim_sizes() const {
    const auto all_operands = absl::MakeSpan(operands());
    return all_operands.subspan(1);
  }

  // Returns the input dim size dimension, which is operands[1+i].
  HloInstruction* dim_sizes(int64 i) const { return operands()[i + 1]; }

  // Subclass-specific part of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;
};
// Reshape instruction. Besides its operand it stores an `inferred_dimension`
// value.
class HloReshapeInstruction : public HloInstruction {
 public:
  explicit HloReshapeInstruction(const Shape& shape, HloInstruction* operand,
                                 int64 inferred_dimension);

  // Returns the inferred dimension this instruction was created with.
  int64 inferred_dimension() const { return inferred_dimension_; }

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

 private:
  // Printing hook for the attributes owned by this class.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;

  // Equality hook for the attributes owned by this class.
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // In-class initializer, matching the other scalar members in this header,
  // so the member never holds an indeterminate value. -1 mirrors the default
  // used for feature_index_ in HloBatchNormInstruction.
  // NOTE(review): presumably -1 is also the "no inferred dimension" value —
  // confirm against the constructor and ToProto() definitions.
  int64 inferred_dimension_ = -1;
};
// Map instruction: a list of operands together with the computation to map
// over them. Also keeps a dimension list.
class HloMapInstruction : public HloInstruction {
 public:
  explicit HloMapInstruction(const Shape& shape,
                             absl::Span<HloInstruction* const> operands,
                             HloComputation* map_computation);

  // Returns the dimension sizes or numbers associated with this instruction.
  const std::vector<int64>& dimensions() const override { return dimensions_; }
  // Returns the `index`-th entry of dimensions(); no bounds check.
  int64 dimensions(int64 index) const override { return dimensions()[index]; }
  // Gives mutable access to the stored dimension list.
  std::vector<int64>* mutable_dimensions() override { return &dimensions_; }

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

 private:
  // Subclass-specific part of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Equality hook for the attributes owned by this class.
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // Printing hook for the attributes owned by this class.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;

  // Elementwise-ness hook for this instruction.
  bool IsElementwiseImpl(
      const absl::optional<int64>& operand_idx) const override;

  std::vector<int64> dimensions_;
};
class HloSliceInstruction : public HloInstruction {
 public:
  // Creates a slice of `operand`. Defined out of line; the three spans are
  // presumably indexed by dimension, matching the accessors below.
  explicit HloSliceInstruction(const Shape& shape, HloInstruction* operand,
                               absl::Span<const int64> start_indices,
                               absl::Span<const int64> limit_indices,
                               absl::Span<const int64> strides);

  // Serializes this instruction into an HloInstructionProto.
  HloInstructionProto ToProto() const override;

  // Per-dimension lookups. `dimension` indexes the stored vector directly via
  // operator[], so no range check is performed here.
  //
  // Start index of the slice in `dimension`.
  int64 slice_starts(int64 dimension) const {
    return slice_starts_[dimension];
  }
  // Exclusive limit index of the slice in `dimension`.
  int64 slice_limits(int64 dimension) const {
    return slice_limits_[dimension];
  }
  // Stride of the slice in `dimension`.
  int64 slice_strides(int64 dimension) const {
    return slice_strides_[dimension];
  }

  // Read-only views of the complete start / limit / stride vectors.
  const std::vector<int64>& slice_starts() const { return slice_starts_; }
  const std::vector<int64>& slice_limits() const { return slice_limits_; }
  const std::vector<int64>& slice_strides() const { return slice_strides_; }

  // Pointers to the stored vectors, for callers that modify them in place.
  std::vector<int64>* mutable_slice_starts() { return &slice_starts_; }
  std::vector<int64>* mutable_slice_limits() { return &slice_limits_; }
  std::vector<int64>* mutable_slice_strides() { return &slice_strides_; }

 private:
  // HloInstruction hooks overridden for slices; the definitions are not in
  // this header.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Slice-specific part of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // The slice covers [slice_starts_, slice_limits_) in each dimension,
  // stepping by slice_strides_.
  std::vector<int64> slice_starts_;
  std::vector<int64> slice_limits_;
  std::vector<int64> slice_strides_;
};
class HloConstantInstruction : public HloInstruction {
 public:
  // Constructors (all defined out of line): from a literal alone, or from a
  // literal together with an explicit shape.
  explicit HloConstantInstruction(Literal literal);
  explicit HloConstantInstruction(Literal literal, const Shape& shape);
  // Shape-only form, used when the literal is too large and dropped.
  explicit HloConstantInstruction(const Shape& shape);

  // True when a literal is attached to this instruction.
  bool HasLiteral() const { return literal_.has_value(); }

  // The attached literal. The optional is dereferenced without a check, so
  // the caller must know a literal is present (see HasLiteral()).
  const Literal& literal() const { return *literal_; }

  // Mutable access to the attached literal, obtained through
  // absl::optional::value().
  Literal* mutable_literal() { return &literal_.value(); }

  // Serializes this instruction into an HloInstructionProto.
  HloInstructionProto ToProto() const override;

  // Changes the layout of a constant HLO instruction to match `new_layout`.
  // For tuple-shaped constants, `shape_index` is the path to the internal
  // array subshape whose layout needs to be changed.
  void RelayoutConstant(const Layout& new_layout,
                        const ShapeIndex& shape_index = {});

 private:
  // HloInstruction hooks overridden for constants; the definitions are not in
  // this header.
  bool IsElementwiseImpl(
      const absl::optional<int64>& operand_idx) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  string OperandsToStringWithCanonicalNameMap(
      const HloPrintOptions& options,
      CanonicalNameMap* canonical_name_map) const override;
  // Constant-specific part of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // The constant's value; empty when no literal is attached.
  absl::optional<Literal> literal_;
};
class HloTraceInstruction : public HloInstruction {
 public:
  // Creates a trace instruction on `operand` carrying `tag`. Defined out of
  // line.
  explicit HloTraceInstruction(const string& tag, HloInstruction* operand);

  // The tag to be used in tracing, read back from the stored literal as an
  // R1 U8 string.
  string TracingTag() const {
    return literal_.GetR1U8AsString();
  }

  // Serializes this instruction into an HloInstructionProto.
  HloInstructionProto ToProto() const override;

 private:
  // HloInstruction hook overridden for traces; defined out of line.
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Trace-specific part of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Literal from which TracingTag() reads the tag.
  Literal literal_;
};
class HloFusionInstruction : public HloInstruction {
 public:
  // Creates a fusion instruction of the given kind from a fused root
  // instruction. Defined out of line.
  explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind,
                                HloInstruction* fused_root);
  // Creates a fusion instruction of the given kind from explicit operands and
  // an existing fusion computation. Defined out of line.
  explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind,
                                absl::Span<HloInstruction* const> operands,
                                HloComputation* fusion_computation);

  string ToCategory() const override;

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // Adds a new operand to the fusion instruction.
  HloInstruction* AddFusionOperand(HloInstruction* new_operand);

  // Merges the fused instructions from 'instruction_to_merge' into the
  // fused instruction set of 'this', updating operands as necessary.
  //
  // Precondition: 'instruction_to_merge' must be an operand of 'this'.
  void MergeFusionInstruction(HloFusionInstruction* instruction_to_merge);

  // Merges the fused instructions from instruction_to_merge into the fused
  // instruction set of 'this' and generates multioutput fusion instructions.
  // All the users of instruction_to_merge will be redirected to 'this'
  // instruction. instruction_to_merge will be removed from its parent
  // computation.
  void MergeFusionInstructionIntoMultiOutput(
      HloFusionInstruction* instruction_to_merge);

  // Fuses the given instruction in this fusion instruction. instruction_to_fuse
  // is cloned and the clone is placed in the fusion
  // instruction. instruction_to_fuse is unchanged. Instruction is cloned rather
  // than moved to cleanly handle the case where the instruction has a use
  // outside the fusion instruction. Moving such an instruction into a fusion
  // instruction would violate the single-result invariant of HLO instructions
  // and significantly complicate code generation.
  HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse) {
    return FuseInstructionInternal(instruction_to_fuse);
  }

  // Fuses the given instruction in this fusion instruction and generates a
  // multioutput fusion instruction. A clone of the instruction_to_fuse will
  // be part of the output of fusion instructions. The users of
  // instruction_to_fuse will be redirected to this fusion instructions.
  // instruction_to_fuse is unchanged otherwise.
  HloInstruction* FuseInstructionIntoMultiOutput(
      HloInstruction* instruction_to_fuse) {
    return FuseInstructionInternal(instruction_to_fuse, /*add_output=*/true);
  }

  // Returns the computation for this fused instruction.
  HloComputation* fused_instructions_computation() const;

  // Returns the root instruction of the fused expression contained within this
  // fusion instruction.
  HloInstruction* fused_expression_root() const;

  // Returns the list of fused instructions inside this fusion instruction. The
  // returned type is a range of HloInstruction*s.
  const tensorflow::gtl::iterator_range<UnwrappingIterator<
      std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
  fused_instructions() const;
  const tensorflow::gtl::iterator_range<
      UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
  fused_instructions();

  // Gets the number of instructions inside this fusion instruction.
  int64 fused_instruction_count() const;

  // Returns the fused parameter instruction in this fusion instruction
  // corresponding to the given parameter number.
  HloInstruction* fused_parameter(int64 parameter_number) const;

  // Returns the vector of fused parameters inside this fusion instruction.
  const std::vector<HloInstruction*>& fused_parameters() const;

  // Returns true if this fusion instruction generates multiple outputs, i.e.
  // the root of the fused expression is a kTuple.
  //
  // The return type is plain `bool`: a top-level const on a by-value return
  // has no effect for callers.
  bool IsMultiOutputFusion() const {
    return fused_expression_root()->opcode() == HloOpcode::kTuple;
  }

  // Gets/sets the kind of this fusion.
  FusionKind fusion_kind() const { return fusion_kind_; }
  void set_fusion_kind(FusionKind kind) { fusion_kind_ = kind; }

  // If multiple operands are the same instruction, keeps only one of them.
  Status DeduplicateFusionOperands();

 private:
  // Fuses the given instruction into this fusion instruction.
  // instruction_to_fuse is cloned and the clone is placed in the fusion
  // instruction. The users of instruction_to_fuse will be redirected to this
  // fusion instruction. instruction_to_fuse is unchanged otherwise. When
  // add_output is true, a clone of the instruction_to_fuse will be added as
  // additional output resulting in a multi-output fusion.
  HloInstruction* FuseInstructionInternal(HloInstruction* instruction_to_fuse,
                                          bool add_output = false);

  // Clones the given instruction_to_fuse and inserts the clone into this
  // fusion instruction. If add_output is true, a clone of instruction_to_fuse
  // will be in the output of this fusion instruction (part of the tuple of
  // the fusion root).
  HloInstruction* CloneAndFuseInternal(HloInstruction* instruction_to_fuse,
                                       bool add_output = false);

  // HloInstruction hooks overridden for fusions; the definitions are not in
  // this header.
  bool IsElementwiseImpl(
      const absl::optional<int64>& operand_idx) const override;
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  uint64 InnerHash() const override;

  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // The type of the fusion. Used by kFusion only.
  FusionKind fusion_kind_;
};
class HloRngInstruction : public HloInstruction {
 public:
  // Creates an rng instruction drawing from `distribution`; `parameters` are
  // passed as a span of instructions. Defined out of line.
  explicit HloRngInstruction(const Shape& shape,
                             RandomDistribution distribution,
                             absl::Span<HloInstruction* const> parameters);

  // The random distribution requested for this rng node.
  RandomDistribution random_distribution() const {
    return distribution_;
  }

  // Serializes this instruction into an HloInstructionProto.
  HloInstructionProto ToProto() const override;

 private:
  // HloInstruction hooks overridden for rng; the definitions are not in this
  // header.
  bool IsElementwiseImpl(
      const absl::optional<int64>& operand_idx) const override;
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Rng-specific part of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Distribution used for random number generation.
  RandomDistribution distribution_;
};
class HloParameterInstruction : public HloInstruction {
 public:
  // Creates a parameter instruction with the given number, shape and name.
  // Defined out of line.
  explicit HloParameterInstruction(int64 parameter_number, const Shape& shape,
                                   const string& name);

  // The parameter number of this instruction.
  int64 parameter_number() const { return parameter_number_; }

  // Records, for each leaf buffer, whether all replicas will receive the same
  // parameter data in data parallelism. Both overloads CHECK that the number
  // of flags equals the leaf count of this instruction's shape.
  //
  // Span overload: builds the stored vector from the span's element range.
  void set_parameter_replicated_at_leaf_buffers(
      absl::Span<const bool> parameter_replicated_at_leaf_buffers) {
    CHECK_EQ(ShapeUtil::GetLeafCount(shape()),
             parameter_replicated_at_leaf_buffers.size());
    parameter_replicated_at_leaf_buffers_.emplace(
        parameter_replicated_at_leaf_buffers.begin(),
        parameter_replicated_at_leaf_buffers.end());
  }
  // Vector overload: copies the given vector into the stored optional.
  void set_parameter_replicated_at_leaf_buffers(
      const std::vector<bool>& parameter_replicated_at_leaf_buffers) {
    CHECK_EQ(ShapeUtil::GetLeafCount(shape()),
             parameter_replicated_at_leaf_buffers.size());
    parameter_replicated_at_leaf_buffers_ =
        parameter_replicated_at_leaf_buffers;
  }

  // The per-leaf-buffer replication flags; empty if neither setter has been
  // called in this header's view of the class.
  const absl::optional<std::vector<bool>>&
  parameter_replicated_at_leaf_buffers() const {
    return parameter_replicated_at_leaf_buffers_;
  }

  // Serializes this instruction into an HloInstructionProto.
  HloInstructionProto ToProto() const override;

 private:
  // HloInstruction hooks overridden for parameters; the definitions are not
  // in this header.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  string OperandsToStringWithCanonicalNameMap(
      const HloPrintOptions& options,
      CanonicalNameMap* canonical_name_map) const override;
  // Parameter-specific part of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  int64 parameter_number_ = 0;

  // Whether each buffer has the same parameter value on all replicas in data
  // parallelism.
  absl::optional<std::vector<bool>> parameter_replicated_at_leaf_buffers_;
};
class HloGetTupleElementInstruction : public HloInstruction {
 public:
  // Creates an instruction that refers to element `index` of `operand`.
  // Defined out of line.
  explicit HloGetTupleElementInstruction(const Shape& shape,
                                         HloInstruction* operand, int64 index);

  // The tuple index associated with this instruction.
  int64 tuple_index() const {
    return tuple_index_;
  }

  // Replaces the tuple index associated with this instruction.
  void set_tuple_index(int64 new_tuple_index) {
    tuple_index_ = new_tuple_index;
  }

  // Serializes this instruction into an HloInstructionProto.
  HloInstructionProto ToProto() const override;

 private:
  // HloInstruction hooks overridden for get-tuple-element; the definitions
  // are not in this header.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Get-tuple-element-specific part of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // In-class default is -1.
  int64 tuple_index_ = -1;
};
class HloReducePrecisionInstruction : public HloInstruction {
 public:
  // Creates a reduce-precision instruction on `operand` with the given
  // exponent and mantissa bit counts. Defined out of line.
  //
  // The by-value int parameters carry no top-level const in this declaration:
  // such a qualifier is not part of the function type and means nothing to
  // callers, so a definition that keeps `const int` still matches.
  explicit HloReducePrecisionInstruction(const Shape& shape,
                                         HloInstruction* operand,
                                         int exponent_bits, int mantissa_bits);

  // Returns the number of exponent bits for a reduce-precision node.
  int32 exponent_bits() const { return exponent_bits_; }

  // Returns the number of mantissa bits for a reduce-precision node.
  int32 mantissa_bits() const { return mantissa_bits_; }

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

 private:
  // HloInstruction hooks overridden for reduce-precision; the definitions are
  // not in this header.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // The bit sizes for a reduce-precision operation.
  int32 exponent_bits_ = 0;
  int32 mantissa_bits_ = 0;
};
class HloInfeedInstruction : public HloInstruction {
 public:
  // Creates an infeed instruction producing data of `infeed_shape`, taking a
  // token operand and a configuration string. Defined out of line.
  explicit HloInfeedInstruction(const Shape& infeed_shape,
                                HloInstruction* token_operand,
                                const string& config);

  // Returns the infeed configuration string. The infeed configuration includes
  // any metadata needed for the backend compiler (e.g., infeed buffer address)
  // and is target-dependent.
  //
  // Returned by const reference (like HloOutfeedInstruction::outfeed_config())
  // so that reading the config does not copy the string. The reference is to
  // the stored member, so it reflects later set_infeed_config() calls.
  const string& infeed_config() const { return infeed_config_; }
  void set_infeed_config(const string& config) { infeed_config_ = config; }

  // Returns the shape of the data received by the infeed. This is not the same
  // as the shape of the infeed instruction which produces a tuple containing
  // the infeed data shape and a TOKEN; the data shape is the subshape at
  // index {0} of that tuple.
  const Shape& infeed_shape() const {
    TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape()));
    return ShapeUtil::GetSubshape(shape(), {0});
  }

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

 private:
  // HloInstruction hooks overridden for infeed; the definitions are not in
  // this header.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // The string representation of the infeed configuration.
  string infeed_config_;
};
class HloOutfeedInstruction : public HloInstruction {
 public:
  // Creates an outfeed of `operand` with the given outfeed shape, token
  // operand and configuration string. Defined out of line.
  explicit HloOutfeedInstruction(const Shape& outfeed_shape,
                                 HloInstruction* operand,
                                 HloInstruction* token_operand,
                                 absl::string_view outfeed_config);

  // Shape for the outfeed instruction: read-only reference, or a pointer for
  // modifying it in place.
  const Shape& outfeed_shape() const { return outfeed_shape_; }
  Shape* mutable_outfeed_shape() { return &outfeed_shape_; }

  // Config for the outfeed instruction: getter and setter.
  const string& outfeed_config() const { return outfeed_config_; }
  void set_outfeed_config(const string& config) {
    outfeed_config_ = config;
  }

  // Serializes this instruction into an HloInstructionProto.
  HloInstructionProto ToProto() const override;

 private:
  // HloInstruction hooks overridden for outfeed; the definitions are not in
  // this header.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Outfeed-specific part of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Shape of the outfeed request.
  Shape outfeed_shape_;
  // Outfeed configuration information, only present for kOutfeed.
  string outfeed_config_;
};
class HloConvolutionInstruction : public HloInstruction {
 public:
  // Creates a convolution of `lhs` and `rhs` with the given group counts,
  // window, dimension numbers and precision config. Defined out of line.
  explicit HloConvolutionInstruction(
      const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
      int64 feature_group_count, int64 batch_group_count, const Window& window,
      const ConvolutionDimensionNumbers& dimension_numbers,
      const PrecisionConfig& precision_config);

  // Window used for the convolution (overrides of the HloInstruction
  // accessors).
  const Window& window() const override { return window_; }
  void set_window(const Window& window) override { window_ = window; }

  // Dimension numbers used for the convolution.
  const ConvolutionDimensionNumbers& convolution_dimension_numbers() const {
    return convolution_dimension_numbers_;
  }
  void set_convolution_dimension_numbers(
      const ConvolutionDimensionNumbers& dnums) {
    convolution_dimension_numbers_ = dnums;
  }

  // Number of feature groups. Must be a divisor of the input feature
  // dimension and output feature dimension.
  int64 feature_group_count() const { return feature_group_count_; }
  void set_feature_group_count(int64 num_feature_groups) {
    feature_group_count_ = num_feature_groups;
  }

  // Number of batch groups. Must be a divisor of the input batch dimension.
  int64 batch_group_count() const { return batch_group_count_; }
  void set_batch_group_count(int64 num_batch_groups) {
    batch_group_count_ = num_batch_groups;
  }

  // Tells the implementation what sort of precision is requested. The meaning
  // of the field is backend specific. At the moment, it is only supported for
  // kConvolution and kDot. Transformations on one kDot or kConvolution to
  // another will preserve this information. Transformations to other HLOs
  // will not preserve this information but it is presumed that the alternate
  // lowering is strictly superior.
  const PrecisionConfig& precision_config() const { return precision_config_; }
  PrecisionConfig* mutable_precision_config() { return &precision_config_; }

  string ToCategory() const override;

  // Serializes this instruction into an HloInstructionProto.
  HloInstructionProto ToProto() const override;

 private:
  // HloInstruction hooks overridden for convolutions; the definitions are not
  // in this header.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Convolution-specific part of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // See feature_group_count().
  int64 feature_group_count_;
  // See batch_group_count().
  int64 batch_group_count_;
  // Window used for the convolution.
  Window window_;
  // Dimension numbers used for the convolution.
  ConvolutionDimensionNumbers convolution_dimension_numbers_;
  // Communicates to the implementation about the algorithm used to produce
  // results. See the documentation on precision_config().
  PrecisionConfig precision_config_;
};
class HloReduceWindowInstruction : public HloInstruction {
 public:
  // Creates a reduce-window over `operand` with an initial value, a window
  // and a reduction computation. Defined out of line.
  explicit HloReduceWindowInstruction(const Shape& shape,
                                      HloInstruction* operand,
                                      HloInstruction* init_value,
                                      const Window& window,
                                      HloComputation* reduce_computation);

  // Window accessors (overrides of the HloInstruction accessors), backed by
  // window_.
  const Window& window() const override {
    return window_;
  }
  void set_window(const Window& window) override {
    window_ = window;
  }

  // Serializes this instruction into an HloInstructionProto.
  HloInstructionProto ToProto() const override;

 private:
  // HloInstruction hooks overridden for reduce-window; the definitions are
  // not in this header.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Reduce-window-specific part of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Window returned by window().
  Window window_;
};
class HloSelectAndScatterInstruction : public HloInstruction {
 public:
  // Creates a select-and-scatter from an operand, a select computation, a
  // window, a source, an initial value and a scatter computation. Defined out
  // of line.
  explicit HloSelectAndScatterInstruction(
      const Shape& shape, HloInstruction* operand, HloComputation* select,
      const Window& window, HloInstruction* source, HloInstruction* init_value,
      HloComputation* scatter);

  // Window accessors (overrides of the HloInstruction accessors), backed by
  // window_.
  const Window& window() const override { return window_; }
  void set_window(const Window& window) override { window_ = window; }

  // Getters for the select and scatter computations, looked up in
  // called_computations() at their fixed indices.
  HloComputation* select() const {
    return called_computations()[kSelectComputationIndex];
  }
  HloComputation* scatter() const {
    return called_computations()[kScatterComputationIndex];
  }

  // Setters for the select and scatter computations. They should only be
  // called by HloModule or HloComputation methods. Changing the computation
  // of a fused instruction is rejected with a CHECK so that called_instructions
  // never has to be recomputed for the entire fusion instruction.
  void set_select(HloComputation* computation) {
    CHECK(!IsFused());
    set_called_computation(kSelectComputationIndex, computation);
  }
  void set_scatter(HloComputation* computation) {
    CHECK(!IsFused());
    set_called_computation(kScatterComputationIndex, computation);
  }

  // Serializes this instruction into an HloInstructionProto.
  HloInstructionProto ToProto() const override;

 private:
  // HloInstruction hooks overridden for select-and-scatter; the definitions
  // are not in this header.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Select-and-scatter-specific part of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Window returned by window().
  Window window_;
};
class HloCustomCallInstruction : public HloInstruction {
 public:
  // Plain custom call to `custom_call_target`. Defined out of line.
  HloCustomCallInstruction(const Shape& shape,
                           absl::Span<HloInstruction* const> operands,
                           absl::string_view custom_call_target, string opaque);

  // Custom call with constrained layout. 'shape' and
  // 'operand_shapes_with_layout' must all have layouts.
  HloCustomCallInstruction(const Shape& shape,
                           absl::Span<HloInstruction* const> operands,
                           absl::string_view custom_call_target, string opaque,
                           absl::Span<const Shape> operand_shapes_with_layout);

  // Custom call with a to_apply computation.
  HloCustomCallInstruction(const Shape& shape,
                           absl::Span<HloInstruction* const> operands,
                           HloComputation* to_apply,
                           absl::string_view custom_call_target, string opaque);

  // Window of this custom call. The getter CHECK-fails if no window has been
  // set; the setter stores a heap-allocated copy of `window`.
  const Window& window() const override {
    CHECK(window_ != nullptr);
    return *window_;
  }
  void set_window(const Window& window) override {
    window_ = absl::make_unique<Window>(window);
  }

  // Convolution dimension numbers of this custom call. The getter CHECK-fails
  // if none have been set; the setter stores a heap-allocated copy of `dnums`.
  const ConvolutionDimensionNumbers& convolution_dimension_numbers() const {
    CHECK(convolution_dimension_numbers_ != nullptr);
    return *convolution_dimension_numbers_;
  }
  void set_convolution_dimension_numbers(
      const ConvolutionDimensionNumbers& dnums) {
    convolution_dimension_numbers_ =
        absl::make_unique<ConvolutionDimensionNumbers>(dnums);
  }

  // TODO(jpienaar): Remove this accessor in the follow up.
  // Forwards to raw_backend_config_string().
  const string& opaque() const { return raw_backend_config_string(); }

  // Name of the global symbol this custom call targets.
  const string& custom_call_target() const { return custom_call_target_; }

  // Feature group count: getter and setter.
  int64 feature_group_count() const { return feature_group_count_; }
  void set_feature_group_count(int64 feature_group_count) {
    feature_group_count_ = feature_group_count;
  }

  // Batch group count: getter and setter.
  int64 batch_group_count() const { return batch_group_count_; }
  void set_batch_group_count(int64 batch_group_count) {
    batch_group_count_ = batch_group_count;
  }

  // Whether this custom call has a side-effect - by default a custom call has
  // no side-effects.
  bool custom_call_has_side_effect() const {
    return custom_call_has_side_effect_;
  }
  void set_custom_call_has_side_effect(bool custom_call_has_side_effect) {
    custom_call_has_side_effect_ = custom_call_has_side_effect;
  }

  // Serializes this instruction into an HloInstructionProto.
  HloInstructionProto ToProto() const override;

  // Whether the result and operand layouts are constrained.
  bool layout_constrained() const { return layout_constrained_; }

  // Shapes (with layout) of the operands. CHECKs if this custom call does not
  // have constrained layouts.
  const std::vector<Shape>& operand_shapes_with_layout() const {
    CHECK(layout_constrained());
    return operand_shapes_with_layout_;
  }

  // List of output/operand buffer pairs that alias each other. The output
  // buffer is represented as a ShapeIndex, and the operand buffer as the
  // operand index plus a ShapeIndex. By default this list is empty.
  const std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>&
  output_to_operand_aliasing() const {
    return output_to_operand_aliasing_;
  }
  // Replaces the list of output/operand buffer pairs that alias each other;
  // the by-value argument is moved into place.
  void set_output_to_operand_aliasing(
      std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
          aliasing) {
    output_to_operand_aliasing_ = std::move(aliasing);
  }

 private:
  // HloInstruction hooks overridden for custom calls; the definitions are not
  // in this header.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Custom-call-specific part of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Name of a global symbol to call.
  string custom_call_target_;
  // Window in a windowed operation such as convolution; null until set.
  std::unique_ptr<Window> window_;
  // Dimension numbers used for a convolution; null until set.
  std::unique_ptr<ConvolutionDimensionNumbers> convolution_dimension_numbers_;
  // The number of feature groups. This is used for grouped convolutions.
  int64 feature_group_count_;
  // The number of batch groups (see batch_group_count()).
  int64 batch_group_count_;
  // Whether the result and operand layouts are constrained.
  bool layout_constrained_;
  // For layout-constrained custom calls, the shape with layout for each
  // operand.
  std::vector<Shape> operand_shapes_with_layout_;
  // Whether this custom call has a side-effect.
  bool custom_call_has_side_effect_;
  // Output/operand buffer pairs that alias each other. See comment of
  // output_to_operand_aliasing().
  std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
      output_to_operand_aliasing_;
};
class HloPadInstruction : public HloInstruction {
 public:
  // Creates a pad of `operand` using `padding_value` and `padding_config`.
  // Defined out of line.
  explicit HloPadInstruction(const Shape& shape, HloInstruction* operand,
                             HloInstruction* padding_value,
                             const PaddingConfig& padding_config);

  // Padding configuration for this pad node: read-only reference, or a
  // pointer for modifying it in place.
  const PaddingConfig& padding_config() const {
    return padding_config_;
  }
  PaddingConfig* mutable_padding_config() {
    return &padding_config_;
  }

  // The padding value, which is operand number 1 of this instruction.
  const HloInstruction* padding_value() const {
    return operand(1);
  }
  HloInstruction* mutable_padding_value() {
    return mutable_operand(1);
  }

  // Serializes this instruction into an HloInstructionProto.
  HloInstructionProto ToProto() const override;

 private:
  // HloInstruction hooks overridden for pad; the definitions are not in this
  // header.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Pad-specific part of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Describes the edge padding and interior padding of this pad instruction.
  PaddingConfig padding_config_;
};
class HloDynamicIndexInstruction : public HloInstruction {
public:
explicit HloDynamicIndexInstruction(HloOpcode opcode, const Shape& shape)
: HloInstruction(opcode, shape) {}
virtual int64 first_index_operand_number() const = 0;
// Returns a subspan of operands which represent the start indices.
absl::Span<HloInstruction* const> index_operands() const {
return absl::MakeSpan(operands()).subspan(first_index_operand_number());
}
// Returns the shapes of the index operands.
std::vector<Shape> index_shapes() const {
std::vector<Shape> shapes;
auto indices = index_operands();
for (const HloInstruction* index : indices) {
shapes.push_back(index->shape());
}
return shapes;
}
};
class HloDynamicSliceInstruction : public HloDynamicIndexInstruction {
 public:
  // Creates a dynamic slice of `operand`. Two forms, both defined out of
  // line: start indices given as a single instruction, or as a span of
  // instructions.
  explicit HloDynamicSliceInstruction(const Shape& shape,
                                      HloInstruction* operand,
                                      HloInstruction* start_indices,
                                      absl::Span<const int64> slice_sizes);
  explicit HloDynamicSliceInstruction(
      const Shape& shape, HloInstruction* operand,
      absl::Span<HloInstruction* const> start_indices,
      absl::Span<const int64> slice_sizes);

  // Size of the slice in the given dimension. `dimension` indexes the stored
  // vector directly via operator[], so no range check is performed here.
  int64 slice_sizes(int64 dimension) const {
    return dynamic_slice_sizes_[dimension];
  }

  // All slice sizes.
  const std::vector<int64>& dynamic_slice_sizes() const {
    return dynamic_slice_sizes_;
  }

  // Serializes this instruction into an HloInstructionProto.
  HloInstructionProto ToProto() const override;

  // Index operands start at operand number 1.
  int64 first_index_operand_number() const override {
    return 1;
  }

 private:
  // HloInstruction hooks overridden for dynamic-slice; the definitions are
  // not in this header.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Dynamic-slice-specific part of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Size of the [start, start + size) range for a dynamic slice. 'start' is
  // specified dynamically by the index operands (operand number 1 onwards,
  // per first_index_operand_number()).
  std::vector<int64> dynamic_slice_sizes_;
};
class HloDynamicUpdateSliceInstruction : public HloDynamicIndexInstruction {
 public:
  // Creates a dynamic-update-slice of `operand` with `update`. Two forms,
  // both defined out of line: start indices given as a single instruction, or
  // as a span of instructions.
  explicit HloDynamicUpdateSliceInstruction(const Shape& shape,
                                            HloInstruction* operand,
                                            HloInstruction* update,
                                            HloInstruction* start_indices);
  explicit HloDynamicUpdateSliceInstruction(
      const Shape& shape, HloInstruction* operand, HloInstruction* update,
      absl::Span<HloInstruction* const> start_indices);

  // Index operands start at operand number 2.
  int64 first_index_operand_number() const override {
    return 2;
  }
};
class HloGatherInstruction : public HloInstruction {
 public:
  // Creates a gather from `operand` at `start_indices` with the given
  // dimension numbers, slice sizes and sortedness flag. Defined out of line.
  explicit HloGatherInstruction(
      const Shape& shape, HloInstruction* operand,
      HloInstruction* start_indices,
      const GatherDimensionNumbers& gather_dim_numbers,
      absl::Span<const int64> slice_sizes, bool indices_are_sorted);

  // The gather dimension numbers; CHECK-fails if the stored pointer is null.
  const GatherDimensionNumbers& gather_dimension_numbers() const {
    CHECK(gather_dimension_numbers_ != nullptr);
    return *gather_dimension_numbers_;
  }

  // View over the stored slice sizes.
  absl::Span<const int64> gather_slice_sizes() const {
    return gather_slice_sizes_;
  }

  // Getter and setter for the indices_are_sorted flag.
  bool indices_are_sorted() const { return indices_are_sorted_; }
  void set_indices_are_sorted(bool indices_are_sorted) {
    indices_are_sorted_ = indices_are_sorted;
  }

  // Serializes this instruction into an HloInstructionProto.
  HloInstructionProto ToProto() const override;

  // Creates an instance of GatherDimensionNumbers.
  static GatherDimensionNumbers MakeGatherDimNumbers(
      absl::Span<const int64> offset_dims,
      absl::Span<const int64> collapsed_slice_dims,
      absl::Span<const int64> start_index_map, int64 index_vector_dim);

  // Returns the dump string of the given gather dimension numbers.
  static string GatherDimensionNumbersToString(
      const GatherDimensionNumbers& gather_dimension_numbers);

 private:
  // HloInstruction hooks overridden for gather; the definitions are not in
  // this header.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Gather-specific part of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Dimension numbers returned by gather_dimension_numbers().
  std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_;
  // Slice sizes returned by gather_slice_sizes().
  std::vector<int64> gather_slice_sizes_;
  // Flag returned by indices_are_sorted().
  bool indices_are_sorted_;
};
// HloInstruction subclass for scatter. It stores the scatter dimension
// numbers together with the indices_are_sorted and unique_indices flags.
class HloScatterInstruction : public HloInstruction {
 public:
  explicit HloScatterInstruction(
      const Shape& shape, HloInstruction* operand,
      HloInstruction* scatter_indices, HloInstruction* updates,
      HloComputation* update_computation,
      const ScatterDimensionNumbers& scatter_dim_numbers,
      bool indices_are_sorted, bool unique_indices);

  // Convenience factory that produces a ScatterDimensionNumbers from the
  // given arguments.
  static ScatterDimensionNumbers MakeScatterDimNumbers(
      absl::Span<const int64> update_window_dims,
      absl::Span<const int64> inserted_window_dims,
      absl::Span<const int64> scatter_dims_to_operand_dims,
      int64 index_vector_dim);

  // Produces the dump string for 'scatter_dimension_numbers'.
  static string ScatterDimensionNumbersToString(
      const ScatterDimensionNumbers& scatter_dimension_numbers);

  // Accessor for the stored scatter dimension numbers. CHECK-fails when the
  // underlying pointer is null.
  const ScatterDimensionNumbers& scatter_dimension_numbers() const {
    CHECK(scatter_dimension_numbers_ != nullptr);
    return *scatter_dimension_numbers_;
  }

  // Getter and setter for the indices_are_sorted flag.
  bool indices_are_sorted() const { return indices_are_sorted_; }
  void set_indices_are_sorted(bool indices_are_sorted) {
    indices_are_sorted_ = indices_are_sorted;
  }

  // Getter for the unique_indices flag. Overrides the base-class virtual; no
  // setter is declared in this class.
  bool unique_indices() const override { return unique_indices_; }

  // Serializes this instruction into an HloInstructionProto.
  HloInstructionProto ToProto() const override;

 private:
  // HloInstruction hooks overridden for scatter: printing of extra
  // attributes and the equality slow path.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // Scatter-specific part of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Held through a pointer; scatter_dimension_numbers() CHECKs it is non-null
  // before dereferencing.
  std::unique_ptr<ScatterDimensionNumbers> scatter_dimension_numbers_;
  // Backing storage for indices_are_sorted().
  bool indices_are_sorted_;
  // Backing storage for unique_indices().
  bool unique_indices_;
};
// HloInstruction subclass for iota. It takes no operands in its constructor
// and records a single, immutable iota dimension.
class HloIotaInstruction : public HloInstruction {
 public:
  explicit HloIotaInstruction(const Shape& shape, int64 iota_dimension);

  // The iota dimension this instruction was constructed with (a single
  // dimension number, fixed for the lifetime of the instruction).
  int64 iota_dimension() const { return iota_dimension_; }

  // Serializes this instruction into an HloInstructionProto.
  HloInstructionProto ToProto() const override;

 private:
  // HloInstruction hooks overridden for iota: printing of extra attributes
  // and the equality slow path.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // Iota-specific part of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Declared const: there is no setter and the value cannot change after
  // construction.
  const int64 iota_dimension_;
};
// HloInstruction subclass for dot. It stores the dot dimension numbers and
// the precision configuration of the operation.
class HloDotInstruction : public HloInstruction {
 public:
  // Builds a dot of 'lhs' and 'rhs'. The contracting and batch dimensions
  // are taken from 'dimension_numbers'.
  explicit HloDotInstruction(const Shape& shape, HloInstruction* lhs,
                             HloInstruction* rhs,
                             const DotDimensionNumbers& dimension_numbers,
                             const PrecisionConfig& precision_config);

  // The dimension numbers this dot operates with.
  const DotDimensionNumbers& dot_dimension_numbers() const {
    return dot_dimension_numbers_;
  }

  // The precision request handed to the implementation. How it is
  // interpreted is backend specific, and currently only kConvolution and
  // kDot support it. A transformation from one kDot or kConvolution to
  // another keeps this information; a transformation to any other HLO drops
  // it, on the presumption that the alternate lowering is strictly superior.
  const PrecisionConfig& precision_config() const { return precision_config_; }

  // Mutable access to the same precision configuration.
  PrecisionConfig* mutable_precision_config() { return &precision_config_; }

  // Serializes this instruction into an HloInstructionProto.
  HloInstructionProto ToProto() const override;

 private:
  // HloInstruction hooks overridden for dot: printing of extra attributes
  // and the equality slow path.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // Dot-specific part of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Dump string for the dot dimension numbers of this instruction.
  string DotDimensionNumbersToString() const;

  // Dimension numbers of this dot.
  DotDimensionNumbers dot_dimension_numbers_;

  // Tells the implementation which algorithm/precision to use when producing
  // results; see precision_config() for the details.
  PrecisionConfig precision_config_;
};
// HloInstruction subclass for kDomain. It owns two DomainMetadata objects:
// one for the operand side and one for the user side.
class HloDomainInstruction : public HloInstruction {
 public:
  // Both metadata objects are passed by unique_ptr.
  explicit HloDomainInstruction(
      const Shape& shape, HloInstruction* operand,
      std::unique_ptr<DomainMetadata> operand_side_metadata,
      std::unique_ptr<DomainMetadata> user_side_metadata);

  // Serializes this instruction into an HloInstructionProto.
  HloInstructionProto ToProto() const override;

  // Operand side metadata of this kDomain instruction. The stored pointer is
  // dereferenced without a null check.
  const DomainMetadata& operand_side_metadata() const {
    return *operand_side_metadata_;
  }

  // User side metadata of this kDomain instruction. The stored pointer is
  // dereferenced without a null check.
  const DomainMetadata& user_side_metadata() const {
    return *user_side_metadata_;
  }

 private:
  // HloInstruction hooks overridden for domain: printing of extra attributes
  // and the equality slow path.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // Domain-specific part of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Owned metadata returned by the accessors above.
  std::unique_ptr<DomainMetadata> operand_side_metadata_;
  std::unique_ptr<DomainMetadata> user_side_metadata_;
};
// HloInstruction subclass for get-dimension-size. It takes one operand and
// records a single dimension number.
class HloGetDimensionSizeInstruction : public HloInstruction {
 public:
  explicit HloGetDimensionSizeInstruction(const Shape& shape,
                                          HloInstruction* operand,
                                          int64 dimension);

  // The single dimension number stored on this instruction (presumably the
  // operand dimension whose size is queried -- confirm against the
  // implementation).
  int64 dimension() const { return dimension_; }

  // Serializes this instruction into an HloInstructionProto.
  HloInstructionProto ToProto() const override;

 private:
  // HloInstruction hooks overridden for get-dimension-size: printing of
  // extra attributes and the equality slow path.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // Part of CloneWithNewOperands that is specific to this subclass.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Backing storage for dimension().
  int64 dimension_;
};
// HloInstruction subclass for set-dimension-size. It takes two operands
// ('operand' and 'val') and records a single dimension number.
class HloSetDimensionSizeInstruction : public HloInstruction {
 public:
  explicit HloSetDimensionSizeInstruction(const Shape& shape,
                                          HloInstruction* operand,
                                          HloInstruction* val, int64 dimension);

  // The single dimension number stored on this instruction (presumably the
  // operand dimension whose size is set from 'val' -- confirm against the
  // implementation).
  int64 dimension() const { return dimension_; }

  // Serializes this instruction into an HloInstructionProto.
  HloInstructionProto ToProto() const override;

 private:
  // HloInstruction hooks overridden for set-dimension-size: printing of
  // extra attributes and the equality slow path.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // Part of CloneWithNewOperands that is specific to this subclass.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Backing storage for dimension().
  int64 dimension_;
};
// HloInstruction subclass for rng-get-and-update-state. It takes no operands
// in its constructor and carries a mutable int64 'delta'.
class HloRngGetAndUpdateStateInstruction : public HloInstruction {
 public:
  explicit HloRngGetAndUpdateStateInstruction(const Shape& shape, int64 delta);

  // Getter and setter for the delta value.
  int64 delta() const { return delta_; }
  void set_delta(int64 delta) { delta_ = delta; }

  // Serializes this instruction into an HloInstructionProto.
  HloInstructionProto ToProto() const override;

 private:
  // HloInstruction hooks overridden for this subclass: printing of extra
  // attributes and the equality slow path.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // Part of CloneWithNewOperands that is specific to this subclass.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Backing storage for delta() / set_delta().
  int64 delta_;
};
// HloInstruction subclass for rng-bit-generator. It takes a 'state' operand
// and records which RandomAlgorithm the instruction uses.
class HloRngBitGeneratorInstruction : public HloInstruction {
 public:
  // Marked explicit to match every other instruction constructor in this
  // file; this only rules out copy-list-initialization from a braced list.
  explicit HloRngBitGeneratorInstruction(const Shape& shape,
                                         HloInstruction* state,
                                         RandomAlgorithm algorithm);

  // The random algorithm this instruction was constructed with.
  RandomAlgorithm algorithm() const { return algorithm_; }

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

 private:
  // HloInstruction hooks overridden for this subclass: printing of extra
  // attributes and the equality slow path.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Backing storage for algorithm(); no setter is declared in this class.
  RandomAlgorithm algorithm_;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_