blob: 2a29b967f37aad884d984d5cb5a781ab6ac45724 [file] [log] [blame]
#include <c10/core/DispatchKeySet.h>
namespace c10 {
// backend_dispatch_keyset should include all runtime backend keys.
// Alias key DispatchKey::CompositeExplicitAutograd maps to backend_dispatch_keyset
// NestedTensor has been explicitly removed due to incompatibility with some
// kernels, such as structured kernels, that use the DefaultBackend key.
constexpr DispatchKeySet backend_dispatch_keyset = autogradother_backends |
DispatchKeySet({
DispatchKey::CPU,
DispatchKey::CUDA,
DispatchKey::XLA,
DispatchKey::XPU,
DispatchKey::PrivateUse1,
DispatchKey::PrivateUse2,
DispatchKey::PrivateUse3,
DispatchKey::MLC,
});
bool isBackendDispatchKey(DispatchKey t) {
return t != DispatchKey::Undefined && backend_dispatch_keyset.has(t);
}
// math_dispatch_keyset contains all keys in backend_dispatch_keyset and autograd_dispatch_keyset
// Alias key DispatchKey::CompositeImplicitAutograd maps to math_dispatch_keyset.
constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset | autograd_dispatch_keyset;
DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) {
TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined);
switch (t) {
case DispatchKey::Autograd:
return autograd_dispatch_keyset;
case DispatchKey::CompositeImplicitAutograd:
return math_dispatch_keyset;
case DispatchKey::CompositeExplicitAutograd:
return backend_dispatch_keyset;
default:
return DispatchKeySet(t);
}
}
// for a given autograd key, return the (guaranteed nonempty) set of associated backend keys.
// for a non-autograd key, return the empty keyset.
DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) {
switch (t) {
case DispatchKey::AutogradCPU:
return DispatchKeySet(DispatchKey::CPU);
case DispatchKey::AutogradCUDA:
return DispatchKeySet(DispatchKey::CUDA);
case DispatchKey::AutogradXLA:
return DispatchKeySet(DispatchKey::XLA);
case DispatchKey::AutogradMLC:
return DispatchKeySet(DispatchKey::MLC);
case DispatchKey::AutogradNestedTensor:
return DispatchKeySet(DispatchKey::NestedTensor);
case DispatchKey::AutogradXPU:
return DispatchKeySet(DispatchKey::XPU);
case DispatchKey::AutogradPrivateUse1:
return DispatchKeySet(DispatchKey::PrivateUse1);
case DispatchKey::AutogradPrivateUse2:
return DispatchKeySet(DispatchKey::PrivateUse2);
case DispatchKey::AutogradPrivateUse3:
return DispatchKeySet(DispatchKey::PrivateUse3);
case DispatchKey::AutogradOther:
return autogradother_backends;
default:
return DispatchKeySet();
}
}
DispatchKeySet getAutogradRelatedKeySetFromBackend(DispatchKey t) {
return DispatchKeySet({
DispatchKey::InplaceOrView, getAutogradKeyFromBackend(t)});
}
bool isIncludedInAlias(DispatchKey k, DispatchKey alias) {
return k != DispatchKey::Undefined && getRuntimeDispatchKeySet(alias).has(k);
}
std::string toString(DispatchKeySet ts) {
std::stringstream ss;
ss << ts;
return ss.str();
}
std::ostream& operator<<(std::ostream& os, DispatchKeySet ts) {
if (ts.empty()) {
os << "DispatchKeySet()";
return os;
}
os << "DispatchKeySet(";
DispatchKey tid;
bool first = true;
while ((tid = ts.highestPriorityTypeId()) != DispatchKey::Undefined) {
if (!first) {
os << ", ";
}
os << tid;
ts = ts.remove(tid);
first = false;
}
os << ")";
return os;
}
}