| #pragma once |
| |
| #include <ir_all_nodes.h> |
| #include <ir_base_nodes.h> |
| #include <parallel_type_bitmap.h> |
| #include <type.h> |
| #include <utils.h> |
| |
| #include <c10/macros/Export.h> |
| #include <c10/util/Optional.h> |
| |
| #include <cstdint> |
| #include <string> |
| #include <unordered_map> |
| #include <vector> |
| |
| namespace torch { |
| namespace jit { |
| namespace fuser { |
| namespace cuda { |
| |
| class IrBuilderPasskey; |
| |
| // Abstract nodes |
| class Val; |
| class Expr; |
| |
| // Values |
| class Bool; |
| class Double; |
| class Int; |
| class NamedScalar; |
| |
| class IterDomain; |
| class TensorDomain; |
| class TensorView; |
| |
| // Expressions |
| class UnaryOp; |
| class BinaryOp; |
| class TernaryOp; |
| class RNGOp; |
| class ReductionOp; |
| class WelfordOp; |
| class BroadcastOp; |
| |
| namespace kir { |
| class Kernel; |
| |
| // Values |
| class Predicate; |
| class TensorIndex; |
| |
| // Expressions |
| class Allocate; |
| class BlockSync; |
| class GridSync; |
| class CpAsyncWait; |
| class CpAsyncCommit; |
| class InitMagicZero; |
| class UpdateMagicZero; |
| class ForLoop; |
| class IfThenElse; |
| class GridReduction; |
| class GroupedGridReduction; |
| class GridBroadcast; |
| class GridWelford; |
| class GroupedGridWelford; |
| class AllocateFusedReduction; |
| |
| // Expr container |
| class Scope; |
| |
| class TORCH_CUDA_CU_API Predicate final : public Val { |
| public: |
| explicit Predicate( |
| IrBuilderPasskey passkey, |
| PredicateType ptype, |
| const Expr* expr = nullptr, |
| Bool* thread_pred = nullptr); |
| |
| explicit Predicate(IrBuilderPasskey passkey, ForLoop* unrolled_loop); |
| |
| explicit Predicate(IrBuilderPasskey passkey, Bool* value); |
| |
| PredicateType predicate_type() const { |
| return ptype_; |
| } |
| |
| const Expr* expr() const { |
| TORCH_INTERNAL_ASSERT( |
| ptype_ != PredicateType::Unswitch && |
| ptype_ != PredicateType::Vectorize && ptype_ != PredicateType::Manual); |
| return expr_; |
| } |
| |
| Bool* thread_pred() const { |
| TORCH_INTERNAL_ASSERT( |
| ptype_ == PredicateType::Inline || |
| ptype_ == PredicateType::Misaligned || ptype_ == PredicateType::Shift || |
| ptype_ == PredicateType::Padding || |
| ptype_ == PredicateType::ReductionWrite); |
| return thread_pred_; |
| } |
| |
| ForLoop* unrolled_loop() const { |
| TORCH_INTERNAL_ASSERT(ptype_ == PredicateType::Unswitch); |
| return unrolled_loop_; |
| } |
| |
| bool hasValue() const { |
| return value_ != nullptr; |
| } |
| |
| Bool* value() const { |
| TORCH_INTERNAL_ASSERT( |
| value_ != nullptr, |
| "The conditional expression for this Predicate is invalid."); |
| return value_; |
| } |
| |
| void setValue(Bool* value) { |
| TORCH_INTERNAL_ASSERT(value != nullptr, "The Bool expression is invalid."); |
| value_ = value; |
| } |
| |
| bool isConst() const final { |
| return hasValue() && value_->isConst(); |
| } |
| |
| private: |
| PredicateType ptype_ = PredicateType::Manual; |
| |
| // For PredicateCompute::getInlinePredicate, |
| // ShiftPredicateInserter::getShiftPredicate and getPaddingPredicate |
| const Expr* expr_ = nullptr; |
| |
| // For PredicateCompute::getInlinePredicate |
| Bool* thread_pred_ = nullptr; |
| |
| // For ParallelType::Unswitch - UnswitchPredicate::get |
| ForLoop* unrolled_loop_ = nullptr; |
| |
| // The Bool conditional value |
| // The value is nullptr until lower_predicate pass |
| Bool* value_ = nullptr; |
| }; |
| |
| class TORCH_CUDA_CU_API TensorIndex final : public Val { |
| public: |
| TensorIndex( |
| IrBuilderPasskey, |
| const TensorView* view, |
| std::vector<Val*> indices); |
| |
| std::vector<Val*>::size_type nDims() const { |
| return indices_.size(); |
| } |
| |
| Val* index(int i) const; |
| |
| const std::vector<Val*>& indices() const { |
| return indices_; |
| } |
| |
| TensorView* view() const { |
| TORCH_INTERNAL_ASSERT(view_ != nullptr); |
| return const_cast<TensorView*>(view_); // NOLINT |
| } |
| |
| private: |
| const TensorView* view_ = nullptr; |
| std::vector<Val*> indices_; |
| }; |
| |
//! Allocate is a lower-level node that describes a buffer of memory
//! required as an intermediate within a kernel. The extent is the
//! expression for the size of the buffer, generated from the TensorView
//! that describes the output of an operation.
| class TORCH_CUDA_CU_API Allocate final : public Expr { |
| public: |
| //! Allocation of a multi-dimensional buffer |
| //! |
| //! param shape Size of each dimension |
| explicit Allocate( |
| IrBuilderPasskey passkey, |
| Val* buffer, |
| MemoryType memory_type, |
| std::vector<Val*> shape = {}, |
| bool zero_init = false); |
| |
| //! Allocation of a non-dimensional buffer |
| //! |
| //! param size Size of allocation |
| explicit Allocate( |
| IrBuilderPasskey passkey, |
| Val* buffer, |
| MemoryType memory_type, |
| Val* size, |
| bool zero_init = false); |
| |
| Expr* shallowCopy() const override; |
| |
| Val* buffer() const { |
| return buffer_; |
| } |
| |
| MemoryType memoryType() const { |
| return memory_type_; |
| } |
| |
| Val* size() const { |
| return size_; |
| } |
| |
| const std::vector<Val*>& shape() const { |
| return shape_; |
| } |
| |
| bool zeroInit() const { |
| return zero_init_; |
| } |
| |
| const Allocate* alias() const { |
| return alias_; |
| } |
| |
| void setAlias(const Allocate* alias) { |
| TORCH_INTERNAL_ASSERT(alias != this); |
| TORCH_INTERNAL_ASSERT(alias->memoryType() == memory_type_); |
| alias_ = alias; |
| } |
| |
| private: |
| Val* buffer_ = nullptr; |
| MemoryType memory_type_ = MemoryType::Local; |
| //! Size of each dimension |
| std::vector<Val*> shape_; |
| bool zero_init_ = false; |
| //! Total size |
| Val* size_ = nullptr; |
| |
  // This alias tracks the next Allocate node in a linked chain of aliases.
  // If the alias is nullptr, then the Allocate node allocates its own
  // memory in the kernel.
| const Allocate* alias_ = nullptr; |
| }; |
| |
// BlockSync represents a __syncthreads() barrier for block-level
// coordination.
//
// TODO(kir): change name to SyncThreads as we could have other barriers.
//
| class TORCH_CUDA_CU_API BlockSync final : public Expr { |
| public: |
| explicit BlockSync(IrBuilderPasskey passkey, bool war_sync = false); |
| |
| Expr* shallowCopy() const override; |
| |
| bool isWarHazardSync() const { |
| return war_sync_; |
| } |
| |
| private: |
| // TODO: war_sync_ is only used for testing/validation purposes. |
| bool war_sync_ = false; |
| }; |
| |
| // CpAsyncWait represents wait intrinsics for cp.async |
| class TORCH_CUDA_CU_API CpAsyncWait final : public Expr { |
| public: |
| explicit CpAsyncWait(IrBuilderPasskey passkey, unsigned int keep_stages = 0); |
| |
| Expr* shallowCopy() const override; |
| |
| //! Returns the remaining number of stages that are not synchronized |
| //! after this op. |
| unsigned int keepStages() const { |
| return keep_stages_; |
| } |
| |
| private: |
  //! Number of stages to leave un-synchronized by this op.
| unsigned int keep_stages_ = 0; |
| }; |
| |
// CpAsyncCommit represents commit intrinsics for cp.async.
// A commit intrinsic communicates the delimiter of a transaction group
// to the async load hardware. For example usage, see [Circular buffer].
| class TORCH_CUDA_CU_API CpAsyncCommit final : public Expr { |
| public: |
| explicit CpAsyncCommit(IrBuilderPasskey passkey); |
| |
| Expr* shallowCopy() const override; |
| }; |
| |
// Synchronizes all blocks in the device; implies that a cooperative
// group launch is required.
| class TORCH_CUDA_CU_API GridSync final : public Expr { |
| public: |
| explicit GridSync( |
| IrBuilderPasskey passkey, |
| ParallelTypeBitmap sync_dims, |
| Val* sync_buffer); |
| |
| Expr* shallowCopy() const override; |
| |
| ParallelTypeBitmap syncDims() const { |
| return sync_dims_; |
| } |
| |
| Val* syncBuffer() const { |
| return sync_buffer_; |
| } |
| |
| private: |
| ParallelTypeBitmap sync_dims_; |
| Val* sync_buffer_ = nullptr; |
| }; |
| |
| // Simply prints "DEFINE_MAGIC_ZERO" in the code in accordance with magic_zero |
| // in helpers.cu |
| class TORCH_CUDA_CU_API InitMagicZero final : public Expr { |
| public: |
| explicit InitMagicZero(IrBuilderPasskey passkey); |
| |
| Expr* shallowCopy() const override; |
| }; |
| |
| // Simply prints "UPDATE_MAGIC_ZERO" in the code in accordance with magic_zero |
| // in helpers.cu |
| class TORCH_CUDA_CU_API UpdateMagicZero final : public Expr { |
| public: |
| explicit UpdateMagicZero(IrBuilderPasskey passkey); |
| |
| Expr* shallowCopy() const override; |
| }; |
| |
| // TODO(kir): promote to IR node |
| class TORCH_CUDA_CU_API Scope { |
| public: |
| explicit Scope(Expr* owner) : owner_(owner) {} |
| |
| const std::vector<Expr*>& exprs() const { |
| return exprs_; |
| } |
| |
| bool empty() const { |
| return exprs_.empty(); |
| } |
| |
| auto size() const { |
| return exprs_.size(); |
| } |
| |
| auto& operator[](size_t i) { |
| return exprs_[i]; |
| } |
| |
| auto& operator[](size_t i) const { |
| return exprs_[i]; |
| } |
| |
| // Insert expr before expression at pos |
| void insert(size_t pos, Expr* expr); |
| |
| // Insert expr before ref |
| void insert_before(Expr* ref, Expr* expr); |
| |
| // Insert expr after ref |
| void insert_after(Expr* ref, Expr* expr); |
| |
| void push_back(Expr* e) { |
| exprs_.push_back(e); |
| } |
| |
| // Erase expr at pos |
| void erase(size_t pos); |
| |
| // Erase expr ref |
| void erase(Expr* ref); |
| |
| bool contains(Expr* expr) const; |
| |
| void clear(); |
| |
| Expr* owner() const { |
| return owner_; |
| } |
| |
| private: |
| // Insert expr before pos |
| void insert(std::vector<Expr*>::const_iterator pos, Expr* expr); |
| |
| // Erase expr at pos |
| void erase(std::vector<Expr*>::const_iterator pos); |
| |
| private: |
| std::vector<Expr*> exprs_; |
| |
  //! Owner expression of this scope, e.g., IfThenElse
| Expr* owner_ = nullptr; |
| }; |
| |
//! ForLoop provides scoping around an int iterator from 0 to range. Exprs
//! placed in its body are considered inside the scope of the for loop. In
//! the future the implementation should look quite different so that we
//! can do proper dependency analysis like in Fusion.
//!
//! TODO(kir): this is not a real expression
//!
//! A ForLoop may represent only part of the iteration domain represented
//! by iter_domain_. In that case, the loop extent field, extent_, may
//! be smaller than the extent of iter_domain_.
| class TORCH_CUDA_CU_API ForLoop final : public Expr { |
| public: |
| //! By default, start and stop are the same as those of iter_domain. |
| //! Step is one by default. |
| //! |
| //! TODO: cleaner way to set options? |
| ForLoop( |
| IrBuilderPasskey passkey, |
| IterDomain* iter_domain, |
| Val* index, |
| Val* start, |
| Val* stop, |
| Val* step, |
| bool vectorize, |
| Val* vectorize_shift, |
| bool unroll_required, |
| DoubleBufferLoopStage double_buffer_loop_stage); |
| |
| ForLoop(IrBuilderPasskey passkey, IterDomain* iter_domain); |
| |
| ForLoop(IrBuilderPasskey passkey, const ForLoop* other); |
| |
| Expr* shallowCopy() const override; |
| |
| Val* index() const { |
| return index_; |
| } |
| |
| Val* start() const; |
| |
| Val* stop() const; |
| |
| Val* step() const; |
| |
| Val* vectorize_shift() const { |
| return vectorize_shift_; |
| } |
| |
| IterDomain* iter_domain() const { |
| return iter_domain_; |
| } |
| |
| // TODO: Return pointer instead of reference to be more consistent |
| Scope& body() { |
| return body_; |
| } |
| |
| const Scope& body() const { |
| return body_; |
| } |
| |
| bool vectorize() const { |
| return vectorize_; |
| } |
| |
| //! True if unrolled (i.e., "#pragma unroll" is attached) |
| bool isUnrolled() const; |
| |
| //! True if unrolling is required |
| bool isUnrollRequired() const { |
| return unroll_required_; |
| } |
| |
| //! Set unrolling required |
| void requireUnroll() { |
| unroll_required_ = true; |
| } |
| |
| //! True if no actual for-loop is materialized |
| bool isTrivial() const; |
| |
| //! Returns the stage of a double buffered iterdomain |
| //! that this for loop materializes. |
| auto doubleBufferLoopStage() const { |
| return double_buffer_loop_stage_; |
| } |
| |
| private: |
| //! Returns if a loop could be unrolled. |
| bool isUnrollable() const; |
| |
| private: |
| IterDomain* const iter_domain_ = nullptr; |
| |
| Val* index_ = nullptr; |
| Val* start_ = nullptr; |
| Val* stop_ = nullptr; |
| Val* step_ = nullptr; |
| |
  // vectorize_ is true when the for-loop contains a vectorized set
  // operation; the flag is used to omit the for-loop from the generated
  // kernel.
  bool vectorize_ = false;
  // [pre | vectorize | post] <= inner-most, merged root domain
  // shift_ is applied to the vectorize and post sections.
  Val* vectorize_shift_ = nullptr;
| |
| //! True if unroll is required for avoiding stack allocation |
| bool unroll_required_ = false; |
| |
| Scope body_; |
| |
| //! Tracks if this for loop is implementing a stage of |
| //! a double buffered iterdomain. |
| DoubleBufferLoopStage double_buffer_loop_stage_ = |
| DoubleBufferLoopStage::NotApplicable; |
| }; |
| |
//! IfThenElse provides scoping for a boolean conditional. Exprs placed in
//! its body are considered inside the scope of the if statement. In the
//! future the implementation should look quite different so that we can
//! do proper dependency analysis like in Fusion.
| //! |
| //! TODO(kir): this is not a real expression |
| //! |
| class TORCH_CUDA_CU_API IfThenElse final : public Expr { |
| public: |
| explicit IfThenElse(IrBuilderPasskey passkey, Predicate* cond); |
| |
| Expr* shallowCopy() const override; |
| |
| Scope& thenBody() { |
| return then_body_; |
| } |
| const Scope& thenBody() const { |
| return then_body_; |
| } |
| |
| Scope& elseBody() { |
| return else_body_; |
| } |
| |
| const Scope& elseBody() const { |
| return else_body_; |
| } |
| |
| bool hasElse() const { |
| return !else_body_.empty(); |
| } |
| |
| private: |
| Scope then_body_; |
| Scope else_body_; |
| }; |
| |
| //! Grid reduction operation |
| //! |
| //! This node is used only after lowering a fusion to explicitly mark a grid |
| //! reduction and the buffer allocation needed to do it. |
| //! |
| //! This node provides FusionExecutor the information it needs to allocate the |
| //! reduction and sync buffers. |
| class TORCH_CUDA_CU_API GridReduction final : public ReductionOp { |
| public: |
| GridReduction( |
| IrBuilderPasskey passkey, |
| BinaryOpType reduction_op_type, |
| Val* init, |
| Val* out, |
| Val* in, |
| Allocate* reduction_buffer, |
| Allocate* sync_buffer, |
| Val* entrance_index, |
| Val* entrances, |
| bool is_allreduce = false); |
| |
| Expr* shallowCopy() const override; |
| |
| Allocate* reduction_buffer() const { |
| return reduction_buffer_; |
| } |
| |
| Allocate* sync_buffer() const { |
| return sync_buffer_; |
| } |
| |
  // Index of the current entrance into this grid reduction
| Val* entrance_index() const { |
| return entrance_index_; |
| } |
| |
  // Total number of times this grid reduction will be entered
| Val* entrances() const { |
| return entrances_; |
| } |
| |
| const ParallelTypeBitmap& threadPredicate() const { |
| return thread_predicate_; |
| } |
| |
| GridReduction* withThreadPredicate( |
| const ParallelTypeBitmap& thread_predicate) { |
| auto result = shallowCopy()->as<GridReduction>(); |
| result->thread_predicate_ = thread_predicate; |
| return result; |
| } |
| |
| private: |
| Allocate* reduction_buffer_ = nullptr; |
| Allocate* sync_buffer_ = nullptr; |
| // gridReduce has template flags for thread predicates. In order to |
| // use them, the thread predicate is held here separately from |
| // Expr::predicate_. |
| ParallelTypeBitmap thread_predicate_; |
| Val* entrance_index_ = nullptr; |
| Val* entrances_ = nullptr; |
| }; |
| |
| class TORCH_CUDA_CU_API GroupedGridReduction final : public GroupedReductionOp { |
| public: |
| GroupedGridReduction( |
| IrBuilderPasskey passkey, |
| std::vector<BinaryOpType> reduction_op_type, |
| std::vector<Val*> init, |
| std::vector<Val*> out, |
| std::vector<Val*> in, |
| std::vector<Allocate*> reduction_buffers, |
| Allocate* sync_buffer, |
| Val* entrance_index, |
| Val* entrances, |
| Val* buffer_stride, |
| bool is_allreduce = false); |
| |
| Expr* shallowCopy() const override; |
| |
| const std::vector<Allocate*>& reduction_buffers() const { |
| return reduction_buffers_; |
| } |
| |
| Allocate* reduction_buffer(size_t i) const { |
| return reduction_buffers_.at(i); |
| } |
| |
| Allocate* sync_buffer() const { |
| return sync_buffer_; |
| } |
| |
  // Index of the current entrance into this grid reduction
| Val* entrance_index() const { |
| return entrance_index_; |
| } |
| |
  // Total number of times this grid reduction will be entered
| Val* entrances() const { |
| return entrances_; |
| } |
| |
| Val* buffer_stride() const { |
| return buffer_stride_; |
| } |
| |
| const ParallelTypeBitmap& threadPredicate() const { |
| return thread_predicate_; |
| } |
| |
| GroupedGridReduction* withThreadPredicate( |
| const ParallelTypeBitmap& thread_predicate) { |
| auto result = shallowCopy()->as<GroupedGridReduction>(); |
| result->thread_predicate_ = thread_predicate; |
| return result; |
| } |
| |
| private: |
| std::vector<Allocate*> reduction_buffers_; |
| Allocate* sync_buffer_ = nullptr; |
| // gridReduce has template flags for thread predicates. In order to |
| // use them, the thread predicate is held here separately from |
| // Expr::predicate_. |
| ParallelTypeBitmap thread_predicate_; |
| Val* entrance_index_ = nullptr; |
| Val* entrances_ = nullptr; |
| // Stride of reduction buffers |
| Val* buffer_stride_ = nullptr; |
| }; |
| |
| //! Grid broadcast operation |
| //! |
| //! This node is used only after lowering a fusion to explicitly mark a grid |
| //! broadcast and the buffer allocation needed to do it. |
| //! |
| //! This node provides FusionExecutor the information it needs to allocate the |
| //! broadcast and sync buffers. |
| class TORCH_CUDA_CU_API GridBroadcast final : public Expr { |
| public: |
| GridBroadcast( |
| IrBuilderPasskey passkey, |
| BroadcastOp* broadcast_op, |
| Allocate* broadcast_buffer, |
| Allocate* sync_buffer); |
| |
| Expr* shallowCopy() const override; |
| |
| BroadcastOp* broadcast_op() const { |
| return broadcast_op_; |
| } |
| |
| Allocate* broadcast_buffer() const { |
| return broadcast_buffer_; |
| } |
| |
| Allocate* sync_buffer() const { |
| return sync_buffer_; |
| } |
| |
| private: |
| BroadcastOp* broadcast_op_ = nullptr; |
| Allocate* broadcast_buffer_ = nullptr; |
| Allocate* sync_buffer_ = nullptr; |
| }; |
| |
| //! Grid welford operation |
| //! |
| //! This node is used only after lowering a fusion to explicitly mark a grid |
| //! reduction and the buffer allocation needed to do it. |
| //! |
| //! This node provides FusionExecutor the information it needs to allocate the |
| //! reduction and sync buffers. |
| //! |
| //! TODO: Make this a subclass of WelfordOp |
| class TORCH_CUDA_CU_API GridWelford final : public Expr { |
| public: |
| GridWelford( |
| IrBuilderPasskey passkey, |
| WelfordOp* welford_op, |
| Allocate* var_buffer, |
| Allocate* avg_buffer, |
| Allocate* n_buffer, |
| Allocate* sync_buffer, |
| Val* entrance_index, |
| Val* entrances); |
| |
| Expr* shallowCopy() const override; |
| |
| WelfordOp* welford_op() const { |
| return welford_op_; |
| } |
| |
| Allocate* var_buffer() const { |
| return var_buffer_; |
| } |
| |
| Allocate* avg_buffer() const { |
| return avg_buffer_; |
| } |
| |
| Allocate* N_buffer() const { |
| return n_buffer_; |
| } |
| |
| Allocate* sync_buffer() const { |
| return sync_buffer_; |
| } |
| |
  // Index of the current entrance into this grid reduction
| Val* entrance_index() const { |
| return entrance_index_; |
| } |
| |
  // Total number of times this grid reduction will be entered
| Val* entrances() const { |
| return entrances_; |
| } |
| |
| const ParallelTypeBitmap& threadPredicate() const { |
| return thread_predicate_; |
| } |
| |
| GridWelford* withThreadPredicate(const ParallelTypeBitmap& thread_predicate) { |
| auto result = shallowCopy()->as<GridWelford>(); |
| result->thread_predicate_ = thread_predicate; |
| return result; |
| } |
| |
| private: |
| WelfordOp* welford_op_ = nullptr; |
| Allocate* var_buffer_ = nullptr; |
| Allocate* avg_buffer_ = nullptr; |
| Allocate* n_buffer_ = nullptr; |
| Allocate* sync_buffer_ = nullptr; |
| Val* entrance_index_ = nullptr; |
| Val* entrances_ = nullptr; |
| // gridReduce has template flags for thread predicates. In order to |
| // use them, the thread predicate is held here separately from |
| // Expr::predicate_. |
| ParallelTypeBitmap thread_predicate_; |
| }; |
| |
| class TORCH_CUDA_CU_API GroupedGridWelford final : public GroupedWelfordOp { |
| public: |
| // input, output and init vals are vectors of triplets |
| GroupedGridWelford( |
| IrBuilderPasskey passkey, |
| std::vector<WelfordTriplet> output_vals, |
| std::vector<WelfordTriplet> input_vals, |
| std::vector<WelfordTriplet> init_vals, |
| std::array<std::vector<Allocate*>, 3> reduction_buffers, |
| Allocate* sync_buffer, |
| Val* entrance_index, |
| Val* entrances, |
| Val* buffer_stride, |
| bool is_allreduce = false); |
| |
| Expr* shallowCopy() const override; |
| |
| const std::array<std::vector<Allocate*>, 3>& reduction_buffers() const { |
| return reduction_buffers_; |
| } |
| |
| Allocate* sync_buffer() const { |
| return sync_buffer_; |
| } |
| |
  // Index of the current entrance into this grid reduction
| Val* entrance_index() const { |
| return entrance_index_; |
| } |
| |
  // Total number of times this grid reduction will be entered
| Val* entrances() const { |
| return entrances_; |
| } |
| |
| Val* buffer_stride() const { |
| return buffer_stride_; |
| } |
| |
| const ParallelTypeBitmap& threadPredicate() const { |
| return thread_predicate_; |
| } |
| |
| GroupedGridWelford* withThreadPredicate( |
| const ParallelTypeBitmap& thread_predicate) { |
| auto result = shallowCopy()->as<GroupedGridWelford>(); |
| result->thread_predicate_ = thread_predicate; |
| return result; |
| } |
| |
| private: |
| std::array<std::vector<Allocate*>, 3> reduction_buffers_; |
| Allocate* sync_buffer_ = nullptr; |
| // gridReduce has template flags for thread predicates. In order to |
| // use them, the thread predicate is held here separately from |
| // Expr::predicate_. |
| ParallelTypeBitmap thread_predicate_; |
| Val* entrance_index_ = nullptr; |
| Val* entrances_ = nullptr; |
| // Stride of reduction buffers |
| Val* buffer_stride_ = nullptr; |
| }; |
| |
| // Allocate an instance of the fused reduction class. |
| class TORCH_CUDA_CU_API AllocateFusedReduction final : public Expr { |
| public: |
| explicit AllocateFusedReduction( |
| IrBuilderPasskey passkey, |
| GridReduction* grid_reduction); |
| |
| explicit AllocateFusedReduction( |
| IrBuilderPasskey passkey, |
| GridWelford* grid_welford); |
| |
| explicit AllocateFusedReduction( |
| IrBuilderPasskey passkey, |
| GroupedGridReduction* grouped_grid_reduction); |
| |
| explicit AllocateFusedReduction( |
| IrBuilderPasskey passkey, |
| GroupedGridWelford* grouped_grid_welford); |
| |
| Expr* shallowCopy() const override; |
| |
| Expr* gridExpr() const { |
| return grid_expr_; |
| } |
| |
| TensorIndex* out() const; |
| |
| const ParallelTypeBitmap& threadPredicate() const; |
| |
| private: |
| //! GridReduction, GridWelford, GroupedGridReduction or GroupedGridWelford |
| Expr* grid_expr_ = nullptr; |
| }; |
| |
//! An IR node consisting of a pair of integers
//! to facilitate definition of 2D swizzle operators.
//! All swizzle 2D ops take two inputs and output
//! an integer pair.
//! TODO:
//!  currently this IR node is only allowed as input
//!  to the new PairSelect node. In follow-ups we may
//!  build out support for out-of-line definition of
//!  the pair alone.
| class TORCH_CUDA_CU_API IntPair : public Val { |
| public: |
| IntPair(IrBuilderPasskey passkey); |
| }; |
| |
| //! An IR node marking selection of first or second |
| //! value from a pair of integers, e.g.: |
| //! Pair(X,Y) -> X or Y. |
| //! This IR node is used to facilitate generation |
| //! of inline 2D swizzle math. |
| class TORCH_CUDA_CU_API PairSelect : public Expr { |
| public: |
| //! Indicates which value from the input |
| //! integer pair to output. |
| enum class Selection { X = 0, Y }; |
| |
| PairSelect(IrBuilderPasskey, Val* out, IntPair* in, Selection selection); |
| |
| Expr* shallowCopy() const override; |
| |
| Val* out() const { |
| return out_; |
| } |
| |
| IntPair* in() const { |
| return in_; |
| } |
| |
| auto selection() const { |
| return selection_; |
| } |
| |
| private: |
| Val* const out_ = nullptr; |
| IntPair* const in_ = nullptr; |
| Selection selection_; |
| }; |
| |
//! An integer IR node that will be generated
//! using custom integer swizzle functions
//! from the CUDA runtime functions.
//! Most supported swizzle functions require
//! the sizes of each dimension to be defined, so
//! all operators take the extents as inputs.
| class TORCH_CUDA_CU_API Swizzle2DInt : public Expr { |
| public: |
| Swizzle2DInt( |
| IrBuilderPasskey, |
| IntPair* out, |
| Val* in_x, |
| Val* in_y, |
| Val* extent_x, |
| Val* extent_y, |
| Swizzle2DType swizzle_type); |
| |
| Expr* shallowCopy() const override; |
| |
| IntPair* out() const { |
| return out_; |
| } |
| |
| Val* inX() const { |
| return in_x_; |
| } |
| |
| Val* inY() const { |
| return in_y_; |
| } |
| |
| Val* extentX() const { |
| return extent_x_; |
| } |
| |
| Val* extentY() const { |
| return extent_y_; |
| } |
| |
| const auto& swizzleType() const { |
| return swizzle_type_; |
| } |
| |
| private: |
| IntPair* const out_ = nullptr; |
| |
| Val* const in_x_ = nullptr; |
| Val* const in_y_ = nullptr; |
| Val* const extent_x_ = nullptr; |
| Val* const extent_y_ = nullptr; |
| Swizzle2DType swizzle_type_; |
| }; |
| |
| } // namespace kir |
| } // namespace cuda |
| } // namespace fuser |
| } // namespace jit |
| } // namespace torch |