blob: f62c4b4d77aae7cb7080be600c27cd5b05c9acf1 [file] [log] [blame]
#pragma once
#include <c10/util/Exception.h>

#include <algorithm>
#include <initializer_list>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// For printing of the set when using a Statement as the type for the set
#include <ir_base_nodes.h>
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
namespace {

// Produces a printable string for an entry reached through a pointer by
// forwarding to the pointee's toString() member. Overload resolution picks
// this form whenever the argument is a pointer.
template <typename T>
auto abstractToString(T* pointee) -> std::string {
  std::string text = pointee->toString();
  return text;
}

// Produces a printable string for an entry held by value. The argument is
// taken by copy, so toString() is invoked on a (non-const) local object.
template <typename T>
auto abstractToString(T value) -> std::string {
  std::string text = value.toString();
  return text;
}

} // namespace
// Vector-like container that rejects duplicate entries. Alongside the vector
// it maintains an unordered_set, so membership checks go through the set while
// iteration follows the order in which entries were first added.
//
// Invariant: vector_ and set_ always hold exactly the same elements.
template <typename T, typename Hash = std::hash<T>>
class VectorOfUniqueEntries {
 public:
  VectorOfUniqueEntries() = default;

  // Builds the container from a list, keeping only the first occurrence of
  // each entry. vector_/set_ are deliberately not initialized straight from
  // the list: that would let duplicates into vector_ while set_ dropped them,
  // breaking the invariant above.
  VectorOfUniqueEntries(const std::initializer_list<T>& x) {
    vector_.reserve(x.size());
    set_.reserve(x.size());
    for (const auto& entry : x) {
      pushBack(entry);
    }
  }

  // Appends entry if it is not already present.
  // Returns true if the entry was actually added.
  bool pushBack(T entry) {
    if (set_.emplace(entry).second) {
      vector_.push_back(entry);
      return true;
    }
    return false;
  }

  // Appends every entry of other that is not already present, in other's
  // order. Returns true if any entry was added.
  bool pushBack(const VectorOfUniqueEntries<T, Hash>& other) {
    bool any_added = false;
    for (const auto& entry : other) {
      // Every entry must be attempted, so don't short-circuit.
      if (pushBack(entry)) {
        any_added = true;
      }
    }
    return any_added;
  }

  // Returns a const vector useful for iterating on
  const std::vector<T>& vector() const {
    return vector_;
  }

  // Returns first element in vector. Precondition: container is non-empty.
  T front() const {
    return vector_.front();
  }

  // Returns last element in vector. Precondition: container is non-empty.
  T back() const {
    return vector_.back();
  }

  // Removes and returns the last element in vector.
  // Precondition: container is non-empty.
  T popBack() {
    T v = vector_.back();
    set_.erase(v);
    vector_.pop_back();
    return v;
  }

  // Returns if this container is empty
  bool empty() const {
    return vector_.empty();
  }

  // Returns the number of elements in this container
  size_t size() const {
    return vector_.size();
  }

  // Returns if entry is in this container
  bool has(T entry) const {
    return set_.find(entry) != set_.end();
  }

  // Erases entry from the container if present; no-op otherwise.
  void erase(T entry) {
    // The set tells us whether entry is present at all; thanks to the
    // uniqueness invariant there is at most one occurrence in vector_.
    if (set_.erase(entry) == 0) {
      return;
    }
    auto it = std::find(vector_.begin(), vector_.end(), entry);
    if (it != vector_.end()) {
      vector_.erase(it);
    }
  }

  // Appends the elements of [begin, end) in order, skipping duplicates.
  template <typename InputIt>
  void insert(InputIt begin, InputIt end) {
    for (auto it = begin; it != end; ++it) {
      pushBack(*it);
    }
  }

  // Returns iterator pointing to the beginning of vector container
  auto begin() const {
    return vector().begin();
  }

  // Returns iterator pointing to the end of vector container
  auto end() const {
    return vector().end();
  }

  // Returns iterator pointing to the beginning of vector container. Iterators
  // are const: elements must not be modified in place, since that would
  // desynchronize vector_ and set_.
  auto begin() {
    return vector().begin();
  }

  // Returns iterator pointing to the end of vector container (const, see
  // begin()).
  auto end() {
    return vector().end();
  }

  // Formats the entries as "{ a; b; c }" in insertion order.
  std::string toString() const {
    std::stringstream ss;
    ss << "{ ";
    bool first = true;
    for (const auto& entry : vector_) {
      if (!first) {
        ss << "; ";
      }
      first = false;
      ss << abstractToString(entry);
    }
    ss << " }";
    return ss.str();
  }

 private:
  std::vector<T> vector_;
  std::unordered_set<T, Hash> set_;
};
//! Container class DisjointSet models equivalence relationships
//!
//! Each instance of this class keeps equivalence sets
//! DisjointSet::mapEntries(a,b) makes the full set of a and b equivalent
//! DisjointSet::*AreMapped(a,b) checks if a and b belong to the same disjoint
//! set
template <typename T, typename Hash = std::hash<T>>
class DisjointSets {
public:
DisjointSets() = default;
// Warning: returned values should never be modified. This accessor isn't
// strictly safe as VectorOfUniqueEntries is not returned as a const.
const std::
unordered_map<T, std::shared_ptr<VectorOfUniqueEntries<T, Hash>>, Hash>&
disjointSetMap() const {
return disjoint_set_maps_;
}
// Warning: returned values should never be modified. This accessor isn't
// strictly safe as VectorOfUniqueEntries is not returned as a const.
const std::vector<std::shared_ptr<VectorOfUniqueEntries<T, Hash>>>&
disjointSets() const {
return disjoint_sets_;
}
// Return the entire disjoint set of provided entry
const VectorOfUniqueEntries<T, Hash>& getDisjointSetOf(T entry) const {
auto set_it = disjoint_set_maps_.find(entry);
TORCH_INTERNAL_ASSERT(
set_it != disjoint_set_maps_.end(),
"Could not find entry for ",
entry->toString());
return *(set_it->second);
}
// Initializes a new set for provided entry
//
// TODO: Return iterator
void initializeSet(T entry) {
if (disjoint_set_maps_.find(entry) != disjoint_set_maps_.end()) {
return;
}
disjoint_sets_.push_back(
std::make_shared<VectorOfUniqueEntries<T, Hash>>());
disjoint_sets_.back()->pushBack(entry);
disjoint_set_maps_.emplace(std::make_pair(entry, disjoint_sets_.back()));
}
// Adds all of the disjoint set belonging to entry1 to the disjoint set
// belonging to entry0, maps all entries of disjoint set belonging to entry1
// to entry0, removes original disjoint set belonging to entry1.
void mapEntries(T entry0, T entry1) {
auto set_it_0 = disjoint_set_maps_.find(entry0);
auto set_it_1 = disjoint_set_maps_.find(entry1);
// Track if we need to reset iterators, optimize for case where both entries
// exist
bool invalid_iterators = false;
if (set_it_0 == disjoint_set_maps_.end()) {
initializeSet(entry0);
invalid_iterators = true;
}
if (set_it_1 == disjoint_set_maps_.end()) {
initializeSet(entry1);
invalid_iterators = true;
}
// TODO: We can avoid refinding one iterator if initialize set returns an
// iterator, though if we insert entry1 we'd have to refind entry0 as it
// could invalidate all iterators
if (invalid_iterators) {
set_it_0 = disjoint_set_maps_.find(entry0);
set_it_1 = disjoint_set_maps_.find(entry1);
}
auto set0_shared_ptr = set_it_0->second;
auto set1_shared_ptr = set_it_1->second;
// If the sets are already the same, do nothing
if (set0_shared_ptr == set1_shared_ptr) {
return;
}
// Place everything in set1 into set0 and remap all entries in set1 to set0
for (auto entry : set1_shared_ptr->vector()) {
set0_shared_ptr->pushBack(entry);
disjoint_set_maps_[entry] = set0_shared_ptr;
}
// set1 no longer needed as its entries are copied into set0
disjoint_sets_.erase(std::find(
disjoint_sets_.begin(), disjoint_sets_.end(), set1_shared_ptr));
}
// Will assert if provided entry0 is not in any disjoint set, otherwise
// returns if entry0 and entry1 are in the same disjoint set.
bool strictAreMapped(T entry0, T entry1) const {
auto entry_it = disjointSetMap().find(entry0);
TORCH_INTERNAL_ASSERT(
entry_it != disjointSetMap().end(),
"Strict mapping failed on element: ",
abstractToString(entry0),
" either an error occurred, or non strict mapping should have been used.");
return entry_it->second->has(entry1);
}
// If entry0 doesn't have a disjoint set returns false, otherwise returns if
// entry0 and entry1 are in the same disjoint set.
bool permissiveAreMapped(T entry0, T entry1) const {
auto entry_it = disjointSetMap().find(entry0);
if (entry_it == disjointSetMap().end()) {
return false;
}
return entry_it->second->has(entry1);
}
// Returns if a set exists with provided entry
bool mappingExists(T entry) const {
return disjoint_set_maps_.find(entry) != disjoint_set_maps_.end();
}
// Returns a deterministic list of all entries that have been added to any
// disjoint set.
//
// Warning: constructed on every call, consider caching result.
VectorOfUniqueEntries<T, Hash> getAllElements() const {
VectorOfUniqueEntries<T, Hash> all_elements;
for (auto set : disjoint_sets_) {
for (auto entry : set->vector()) {
all_elements.pushBack(entry);
}
}
return all_elements;
}
// Completely clears all disjoint sets
void clear() {
disjoint_set_maps_.clear();
disjoint_sets_.clear();
}
std::string toString() const {
std::stringstream ss;
ss << "disjoint sets{\n";
const std::string sep(" ");
for (auto s_ptr : disjoint_sets_) {
auto& set = *s_ptr;
ss << sep << "{\n";
for (auto entry : set.vector()) {
ss << sep << sep << abstractToString(entry) << "\n";
}
ss << sep << "}\n";
}
ss << "}";
return ss.str();
}
private:
// Disjoint sets
std::unordered_map<T, std::shared_ptr<VectorOfUniqueEntries<T, Hash>>, Hash>
disjoint_set_maps_;
// Keep a list of disjoint_sets that's deterministic to iterate over
std::vector<std::shared_ptr<VectorOfUniqueEntries<T, Hash>>> disjoint_sets_;
};
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch