blob: 1804604b8e995f639fcb3cc227e6201ea7a2520f [file] [log] [blame]
#include <c10/core/DispatchKeySet.h>
namespace c10 {
// Runtime keys that the Autograd alias key resolves to, as a key set; returned
// by getRuntimeDispatchKeySet(DispatchKey::Autograd).
// NOTE(review): the same seven keys are listed again in autograd_dispatch_keys
// (used by getRuntimeDispatchKeys); the two lists are expected to agree and
// must be kept in sync by hand.
constexpr DispatchKeySet autograd_dispatch_keyset({
    DispatchKey::AutogradCPU,
    DispatchKey::AutogradCUDA,
    DispatchKey::AutogradXLA,
    DispatchKey::AutogradPrivateUse1,
    DispatchKey::AutogradPrivateUse2,
    DispatchKey::AutogradPrivateUse3,
    DispatchKey::AutogradOther,
});
// Resolves a dispatch key to the set of runtime keys it stands for:
//  - DispatchKey::Autograd  -> every key in autograd_dispatch_keyset
//  - DispatchKey::Undefined -> the empty set
//  - any other key          -> a set containing just that key
DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) {
  if (t == DispatchKey::Autograd) {
    return autograd_dispatch_keyset;
  }
  if (t == DispatchKey::Undefined) {
    return DispatchKeySet();
  }
  return DispatchKeySet(t);
}
template <std::size_t... Is>
constexpr auto make_array_from_sequence(std::index_sequence<Is...>) {
return std::array<DispatchKey, sizeof...(Is)>{static_cast<DispatchKey>(Is)...};
}
constexpr auto runtime_dispatch_keys = make_array_from_sequence(
std::make_index_sequence<static_cast<uint8_t>(DispatchKey::NumDispatchKeys)>{});
// Runtime keys that the Autograd alias key expands to, as returned by
// getRuntimeDispatchKeys(DispatchKey::Autograd). Alias keys get their own
// arrays rather than entries in runtime_dispatch_keys, because that table only
// spans [0, NumDispatchKeys) and must not be sized to cover anything beyond
// DispatchKey::NumDispatchKeys.
// NOTE(review): mirrors autograd_dispatch_keyset; keep the two lists in sync.
constexpr std::array<DispatchKey, 7> autograd_dispatch_keys = {
    DispatchKey::AutogradCPU,
    DispatchKey::AutogradCUDA,
    DispatchKey::AutogradXLA,
    DispatchKey::AutogradPrivateUse1,
    DispatchKey::AutogradPrivateUse2,
    DispatchKey::AutogradPrivateUse3,
    DispatchKey::AutogradOther};
// Returns the runtime keys that `k` stands for, as a view into static storage
// (both backing arrays are namespace-scope constants, so the view never
// dangles):
//  - DispatchKey::Autograd -> the keys in autograd_dispatch_keys
//  - any other alias key   -> TORCH_INTERNAL_ASSERT fails
//  - a non-alias key       -> a one-element view containing `k` itself
// NOTE(review): unlike getRuntimeDispatchKeySet, DispatchKey::Undefined is not
// special-cased here, so (assuming isAliasDispatchKey(Undefined) is false) it
// yields {Undefined}. A non-alias key whose value is >= NumDispatchKeys falls
// outside runtime_dispatch_keys; presumably ArrayRef::slice rejects it — confirm.
ArrayRef<DispatchKey> getRuntimeDispatchKeys(DispatchKey k) {
  if (isAliasDispatchKey(k)) {
    if (k == DispatchKey::Autograd) {
      return autograd_dispatch_keys;
    }
    TORCH_INTERNAL_ASSERT(false, "Unable to resolve alias dispatch key");
  }
  // runtime_dispatch_keys is an identity table, so the element at index k is k.
  const auto index = static_cast<uint8_t>(k);
  return c10::ArrayRef<DispatchKey>(runtime_dispatch_keys).slice(index, 1);
}
// Returns true when runtime key `k` is one of the keys that `alias` resolves to
// via getRuntimeDispatchKeySet. DispatchKey::Undefined is never considered
// included. It is rejected up front, so it is never passed to has().
bool isIncludedInAlias(DispatchKey k, DispatchKey alias) {
  if (k == DispatchKey::Undefined) {
    return false;
  }
  const DispatchKeySet covered = getRuntimeDispatchKeySet(alias);
  return covered.has(k);
}
// Renders `ts` through its stream operator, e.g. "DispatchKeySet()" for an
// empty set or "DispatchKeySet(k1, k2, ...)" otherwise.
std::string toString(DispatchKeySet ts) {
  std::ostringstream out;
  out << ts;
  return out.str();
}
std::ostream& operator<<(std::ostream& os, DispatchKeySet ts) {
if (ts.empty()) {
os << "DispatchKeySet()";
return os;
}
os << "DispatchKeySet(";
DispatchKey tid;
bool first = true;
while ((tid = ts.highestPriorityTypeId()) != DispatchKey::Undefined) {
if (!first) {
os << ", ";
}
os << tid;
ts = ts.remove(tid);
first = false;
}
os << ")";
return os;
}
}