#pragma once
#include <iter_visitor.h>
#include <root_domain_map.h>
#include <unordered_map>
#include <unordered_set>
#include <vector>
/*
 * Index compute takes in a list of indices typically generated from the
 * surrounding for loop nest. The number of indices is intended to match the
 * number of dimensions of the incoming TensorView, which may have fewer or
 * more dimensions than its root due to split/merge operations.
 * Split/merge operations are then replayed backwards to produce resulting
 * indices (based on input indices) that match the root dimension.
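 *
 * As a sketch of the backward replay rules (integer arithmetic assumed):
 *   Split: in -> {outer, inner{factor}}
 *     idx(in) = idx(outer) * factor + idx(inner)
 *   Merge: {outer, inner} -> out
 *     idx(outer) = idx(out) / extent(inner)
 *     idx(inner) = idx(out) % extent(inner)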
*
* For example with GLOBAL tensor:
* TV[I, K]
* TV[Io, Ii{4}, K] = TV.split(I, factor=4)
* ALLOC: NONE
* INDEX: indexCompute {i, j, k} -> {i * 4 + j, k}
* FLATTENED_INDEX: {i * 4 + j, k} -> {(i * 4 + j) * K + k}
* PREDICATE: {i * 4 + j, k} -> i * 4 + j < I
*
*
* For example with SHARED tensor:
*
* global_TV[I, K]
* global_TV[Io, Ii{4}, K] = global_TV.split(I, factor=4)
* smem_TV.compute_at(global_TV, 1)
 * global_TV.parallelize(1, threadIdx.x)
*
* ALLOC: alloc(smem_TV, 4 x K)
* INDEX: indexCompute(smem_TV, {threadIdx.x, k}) -> {threadIdx.x, k}
* FLATTENED_INDEX: {threadIdx.x * 4 + j, k} -> {(threadIdx.x * 4 + j) * K + k}
* PREDICATE: {threadIdx.x * 4 + j, k} -> threadIdx.x * 4 + j < I // Same as if
* global
*
*
* For example with LOCAL tensor:
* global_TV[I, K, L]
* global_TV[Io, Ii{4}, K, L] = global_TV.split(I, factor=4)
* reg_TV.compute_at(global_TV, 2)
 * global_TV.parallelize(1, threadIdx.x)
* global_TV{i, j, k, l} -> { i * 4 + j, k, l }
* global_TV{ i * 4 + j, k, l } -> { (i * 4 + j) * K * L + k * L + l}
*
* ALLOC: alloc(reg_TV, K x L)
* INDEX: {k, l} -> {k, l}
* FLATTENED_INDEX: {k, l} -> {k * L + l}
 * PREDICATE: i * 4 + j < I && k < K && l < L // Same as if global
*
* These indices can then be flattened later based on strides.
*/
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
class ContigIDs;
class LoopIndexing;
struct IndexFromIdGraph;
class IndexCompute : public BackwardVisitor {
protected:
using BackwardVisitor::handle;
void handle(Split*) override;
void handle(Merge*) override;
void handle(Expr*) override;
void handle(Swizzle2D*) override;
// Return extent_map_[id] if it exists, otherwise return id->extent()
Val* getExtent(IterDomain* id) const;
//! True if a domain is not used to index
bool isZero(IterDomain* id) const;
//! True if any dependent of a domain is not used to index
bool hasZeroMerged(IterDomain* id) const;
//! Returns the concrete ID from the compute-at EXACT mode map if
//! concrete_id_pass == true, otherwise returns the id passed in.
//! Helps unify the expr handling logic in reference domain and concrete id
//! based traversal.
IterDomain* maybeGetExactMapConcreteID(IterDomain* id);
//! (Concrete indexing pass only)
//! Collect permissive index binding from the given expression.
//! See also permissive_map_ and LoopIndexing::getBackwardOutOfLineExprList.
void collectIndexIntoPermissiveMap(const LoopIndexing& loop_indexing);
//! (Concrete indexing pass only)
//! Iterate through id_expr's inputs and pull index vals from the
//! permissive map when both of the following are true:
//!  1. the output id is missing in index_map_.
//!  2. the output id is found in the permissive map.
void updateIndexMapFromPermissiveMap(const Expr* id_expr);
// Tensor domain we're mapping back to root
const TensorDomain* td_; // NOLINT
// Map we update as we propagate backward, containing all IDs in the
// propagation. Initial indices are mapped with this map at tv->domain()
// and are back propagated to tv->getRootDomain(). This index_map_ keeps the
// indices at intermediate IterDomain's in that back propagation.
std::unordered_map<IterDomain*, Val*> index_map_; // NOLINT
// Map from IterDomains to their broadcasted extents. If a TV has I0*I1 but
// its producer has B0*I1, this map will contain a mapping from the
// ID{B0*I1} to the extent I0*I1. Also contains updated extents if we merge
// in a 0 index. See zero_merged_in_.
std::unordered_map<IterDomain*, Val*> extent_map_; // NOLINT
// Keeps track of domains that do not contribute to indexing
std::unordered_set<IterDomain*> zero_domains_; // NOLINT
// This set keeps track of IterDomain's that have had a zero index merged
// into them. This happens if we do something like tv->split(0, 4) followed
// by tv->computeAt(1, ...). If this tensor is in smem or lmem, the backward
// indexing would be (0, i); during the backward computation, that zero and
// i would be merged together. We handle indices like these specially.
std::unordered_set<IterDomain*> zero_merged_in_;
// IDs that are a result of contiguous merges
std::unordered_set<IterDomain*> contig_ids_;
// Map from root to indexed domains
std::unordered_map<IterDomain*, IterDomain*> root_to_indexed_id_;
// Indicates whether we should propagate an index down a particular
// IterDomain path when there is a choice
std::unordered_set<IterDomain*> preferred_paths_;
// Map from IterDomains to halo-extended extents
std::unordered_map<IterDomain*, Val*> halo_extent_map_;
// Temporary flag which tells IndexCompute to use concrete IDs from the exact
// map rather than the actual IDs used in the ID expressions.
bool concrete_id_pass_ = false;
// Mode of swizzle that is activated in this index compute instance.
// Swizzles of a different mode are treated as no-ops. Currently, data-mode
// swizzles are handled as before in the IndexSwizzle pass, while loop-mode
// swizzles are handled early on in the concrete indexing pass. See also
// [Note on swizzle mode]
SwizzleMode swizzle_mode_ = SwizzleMode::NoSwizzle;
// (Concrete id pass only)
// Contains the indexing math that can be resolved with only the
// IterDomains on the right of the consumer_tv's ca axis, i.e. the
// ones corresponding to the loops that consumer_tv would not share
// with any of its consumers.
// These indexing vals should be kept separate from index_map_ and
// should only be used when the indexing traversal follows the
// order defined in LoopIndexingAnalysis::traverseFromDomainVals.
std::unordered_map<IterDomain*, Val*> permissive_index_map_;
public:
const std::unordered_map<IterDomain*, Val*>& indexMap() const {
return index_map_;
}
const std::unordered_map<IterDomain*, Val*>& extentMap() const {
return extent_map_;
}
const std::unordered_set<IterDomain*>& zeroDomains() const {
return zero_domains_;
}
const std::unordered_set<IterDomain*>& zeroMergedIn() const {
return zero_merged_in_;
}
const std::unordered_map<IterDomain*, IterDomain*>& rootToContigID() const {
return root_to_indexed_id_;
}
// Propagate back from _td using initial_index_map
IndexCompute(
const TensorDomain* _td,
std::unordered_map<IterDomain*, Val*> initial_index_map,
std::unordered_map<IterDomain*, Val*> _extent_map,
std::unordered_set<IterDomain*> zero_domains,
std::unordered_set<IterDomain*> _zero_merged_in,
std::unordered_set<IterDomain*> preferred_paths = {},
std::unordered_map<IterDomain*, Val*> halo_extent_map = {});
IndexCompute(
const TensorDomain* _td,
std::unordered_map<IterDomain*, Val*> initial_index_map,
std::unordered_map<IterDomain*, Val*> _extent_map,
std::unordered_set<IterDomain*> zero_domains,
std::unordered_set<IterDomain*> _zero_merged_in,
const ContigIDs& contig_finder,
std::unordered_set<IterDomain*> preferred_paths = {},
std::unordered_map<IterDomain*, Val*> halo_extent_map = {});
// Entry point for concrete-id-based traversal. This traversal is assumed to
// start at the leaf IDs provided by initial_index_map.
IndexCompute(
std::unordered_map<IterDomain*, Val*> initial_index_map,
std::unordered_set<IterDomain*> zero_domains,
std::unordered_set<IterDomain*> preferred_paths,
std::unordered_map<IterDomain*, Val*> concrete_halo_extent_map);
// Updates index_map, extent_map, and zero_merged_in based on id_map and
// returns a new IndexCompute ready to be used.
IndexCompute updateIndexCompute(
const TensorDomain* new_td,
const std::unordered_map<IterDomain*, IterDomain*>& id_map,
const ContigIDs& contig_finder) const;
// Interface to run the index traversal through a loop indexing analysis
// result; to be used with the entry point for concrete-id-based traversal.
void run(const LoopIndexing& loop_indexing);
virtual void run();
};
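// A minimal usage sketch (hedged; `tv` and `loop_indices` are hypothetical
// stand-ins, not declarations from this file):
//
//   std::unordered_map<IterDomain*, Val*> initial_index_map;
//   // Bind each leaf IterDomain of tv->domain() to its loop index.
//   for (const auto i : c10::irange(tv->domain()->nDims())) {
//     initial_index_map[tv->axis(i)] = loop_indices[i];
//   }
//   IndexCompute index_compute(
//       tv->domain(), std::move(initial_index_map), {}, {}, {});
//   index_compute.run();
//   // index_compute.indexMap() now holds an index Val* for each
//   // IterDomain reached by the back-propagation, including the roots.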
//! Apply swizzle and update root indices accordingly
class IndexSwizzle : public IndexCompute {
public:
IndexSwizzle(
const TensorView* tv,
std::unordered_map<IterDomain*, Val*> initial_index_map,
std::unordered_map<IterDomain*, Val*> extent_map,
std::unordered_set<IterDomain*> zero_domains,
std::unordered_set<IterDomain*> zero_merged_in);
IndexSwizzle(
const TensorView* tv,
const TensorDomain* domain,
std::unordered_map<IterDomain*, Val*> initial_index_map,
std::unordered_map<IterDomain*, Val*> extent_map,
std::unordered_set<IterDomain*> zero_domains,
std::unordered_set<IterDomain*> zero_merged_in);
void run() override;
protected:
using IndexCompute::handle;
void handle(Expr* e) override;
void handle(Swizzle2D* swizzle_2d) override;
private:
const TensorView* tv_ = nullptr;
SwizzleType swizzle_type_ = SwizzleType::NoSwizzle;
std::vector<IterDomain*> ids_to_swizzle_;
std::unordered_set<IterDomain*> swizzled_ids_;
};
//! Predicate information of a root or contiguous merged domain
class RootPredicateInfo {
friend class Index;
public:
const auto& startPredicate() const {
return start_predicate_;
}
auto& startPredicate() {
return start_predicate_;
}
const auto& startOffset() const {
return start_offset_;
}
const auto& stopPredicate() const {
return stop_predicate_;
}
const auto& stopOffset() const {
return stop_offset_;
}
const auto& rootIds() const {
return root_ids_;
}
//! Return a false RootPredicateInfo, i.e., both start and stop
//! predicates are false.
static RootPredicateInfo getFalseInfo();
private:
// Predicate for the lower end
Bool* start_predicate_ = nullptr;
// Predicate for the upper end
Bool* stop_predicate_ = nullptr;
// Offset of the start predicate
Val* start_offset_ = nullptr;
// Offset of the stop predicate
Val* stop_offset_ = nullptr;
// Track which roots have been handled by the generated predicates
std::unordered_set<IterDomain*> root_ids_;
};
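// As a hedged sketch of how this is typically consumed (the actual offsets
// come from halo/shift analysis elsewhere in lowering), a guarded access of
// a root index `idx` with extent `ext` conceptually takes the form:
//
//   if (idx + start_offset >= 0 && idx + stop_offset < ext) {
//     // ... access guarded by startPredicate() && stopPredicate() ...
//   }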
// Simple interface for IndexCompute
// If getComputeAtAxis, and more generally the TensorView const model, is
// fixed, we can make the below TensorViews const.
class Index {
private:
// Producer indexing if it's in shared or local memory
static std::vector<Val*> getNonGlobalProducerStridedIndices(
TensorView* producer,
const TensorView* consumer,
const std::vector<kir::ForLoop*>& loops);
// Consumer indexing if it's in shared or local memory
static std::vector<Val*> getNonGlobalConsumerStridedIndices(
const TensorView* consumer,
const std::vector<kir::ForLoop*>& loops);
// Producer indexing if it's in global memory
static std::vector<Val*> getGlobalProducerStridedIndices(
TensorView* producer,
const TensorView* consumer,
const std::vector<kir::ForLoop*>& loops);
// Consumer indexing if it's in global memory
static std::vector<Val*> getGlobalConsumerStridedIndices(
const TensorView* consumer,
const std::vector<kir::ForLoop*>& loops);
// Get the strides of a tensor used for index lowering
static std::vector<Val*> getStrides(const TensorView* tv);
// Get the root indices of a tensor used for index lowering
static std::vector<Val*> getRootIndices(
const TensorView* tv,
const std::vector<kir::ForLoop*>& loops,
const IndexFromIdGraph& index_from_id_graph);
public:
// Indexing functions
// Consumer = Producer
// i.e. T0 = T1... -> T0 is the consumer, T1 is the producer
// Producer indexing dispatch
static kir::TensorIndex* getProducerIndex(
TensorView* producer,
const TensorView* consumer,
const std::vector<kir::ForLoop*>& loops);
// Consumer index dispatch
static kir::TensorIndex* getConsumerIndex(
const TensorView* consumer,
const std::vector<kir::ForLoop*>& loops);
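// For example, when lowering T0 = T1 + T2, T0 is indexed once as the
// consumer via getConsumerIndex(T0, loops), while T1 and T2 are each
// indexed as producers via getProducerIndex(T1, T0, loops) and
// getProducerIndex(T2, T0, loops).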
//! Returns a vector of strided indices mapped onto the (rfactor)
//! root domain of a producer tensor. The size of the returned
//! vector is guaranteed to be equal to the number of axes of the
//! indexing root domain.
static std::vector<Val*> getProducerStridedIndices(
TensorView* producer,
const TensorView* consumer,
const std::vector<kir::ForLoop*>& loops);
//! Returns a vector of strided indices mapped onto the (rfactor)
//! root domain of a consumer tensor. The size of the returned
//! vector is guaranteed to be equal to the number of axes of the
//! indexing root domain.
static std::vector<Val*> getConsumerStridedIndices(
const TensorView* consumer,
const std::vector<kir::ForLoop*>& loops);
//! Returns the logical index linearized from a multi-dimensional address
//! into a linear memory address for a consumer tensor. The returned index
//! is intended to be used for the computation of some tensor factories,
//! such as arange and rand (for Philox pseudo-random sequences)
static std::vector<Val*> getLinearLogicalIndex(
TensorView* consumer_tv,
const std::vector<kir::ForLoop*>& loops);
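// For instance, for a consumer with root domain {I, J}, the linearized
// logical index of element (i, j) would be i * J + j, mirroring the
// FLATTENED_INDEX examples at the top of this file.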
//! Returns a vector of logical indices mapped onto the (rfactor)
//! root domain of a consumer tensor. The returned index is intended
//! to be used for the computation of some tensor factories, such as
//! eye
static std::vector<Val*> getPerDimLogicalIndex(
TensorView* consumer_tv,
const std::vector<kir::ForLoop*>& loops);
//! Takes a consumer TensorView and loop nest and generates predicates
//! associated with the concrete roots of the loop nest. Returns a list of
//! predicates and a list of the concrete roots they're associated with. It
//! is assumed that no predicate is required if index[i] is an index taken
//! directly from a for loop. This will not catch all cases if we actually
//! have static size information, for example:
//!
//! TV[I].split(4)
//! would produce the code:
//! for(i : I/4)
//! for(j : 4)
//! if( i * 4 + j < TV.size(0))
//! TV[i * 4 + j]...
//!
//! However, if we had TV.size[0] = 16 at "compile time", then we wouldn't
//! need the predicate. Such cases are caught by canOmitPredicate in
//! predicate lowering
//!
//! unswitch_or_vec_loop is the for loop to start the unswitch-like
//! predicate from. This is not a bool value: if we have an unswitch loop
//! with a vectorized loop inside, we only want to base the "unswitch"-like
//! predicate on the vectorized loop.
static std::vector<RootPredicateInfo> getReferenceRootPredicates(
TensorView* consumer_tv,
const std::vector<kir::ForLoop*>& loops,
kir::ForLoop* unswitch_or_vec_loop,
bool padding_predicate);
};
// Used for local and shared index mapping. Returns a map from loops
// to loop indices as well as a set of loops that do not contribute to
// indexing.
// TODO: could be cleaned up further.
std::pair<
std::unordered_map<kir::ForLoop*, Val*>,
std::unordered_set<kir::ForLoop*>>
indexMapFromTV(
const TensorView* tv,
const std::vector<kir::ForLoop*>& loops,
kir::ForLoop* alloc_loop,
bool as_consumer,
kir::ForLoop* double_buffer_loop = nullptr);
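// A hedged usage sketch (variable names are illustrative):
//
//   auto index_map_and_zeros =
//       indexMapFromTV(tv, loops, alloc_loop, /*as_consumer=*/true);
//   const auto& loop_to_index = index_map_and_zeros.first;
//   const auto& zero_loops = index_map_and_zeros.second;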
//! Set "pragma unroll" required for loops that indexing of Local
//! tensors depends on.
//!
//! \param tv Indexed tensor
//! \param alloc_loop Allocation loop of tv
//! \param loops The current loop structure
//! \param id_map Producer-to-consumer map in case of indexing as producer
void ensureStaticIndexing(
const TensorView* tv,
kir::ForLoop* alloc_loop,
const std::vector<kir::ForLoop*>& loops,
const std::unordered_map<IterDomain*, IterDomain*>& id_map = {});
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch