Resolve comments in #44354. (#45150)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45150
Test Plan: Imported from OSS
Reviewed By: bhosmer
Differential Revision: D23846796
Pulled By: ailzhang
fbshipit-source-id: 7bef89d833848ac3f8993c4c037acf1d4f2ca674
diff --git a/aten/src/ATen/core/boxing/KernelFunction.cpp b/aten/src/ATen/core/boxing/KernelFunction.cpp
index b5d552e..f84352e 100644
--- a/aten/src/ATen/core/boxing/KernelFunction.cpp
+++ b/aten/src/ATen/core/boxing/KernelFunction.cpp
@@ -22,6 +22,7 @@
void ambiguous_autogradother_kernel(OperatorKernel*, const OperatorHandle& op, Stack*) {
TORCH_INTERNAL_ASSERT(0,
op.operator_name(), " has kernels registered to both Math and a backend mapped to AutogradOther. "
+ "This makes the backend kernel unreachable (see Note [Ambiguity in AutogradOther kernel]). "
"If it's intended to override Math kernel behavior, please open an issue to request a dedicated "
"Autograd dispatch key for the backend.");
}
diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.cpp b/aten/src/ATen/core/dispatch/OperatorEntry.cpp
index 5fa379e..0942659 100644
--- a/aten/src/ATen/core/dispatch/OperatorEntry.cpp
+++ b/aten/src/ATen/core/dispatch/OperatorEntry.cpp
@@ -157,10 +157,9 @@
}
bool OperatorEntry::hasKernelForDispatchKeySet(DispatchKeySet ks) const {
- for (auto k : ks) {
- if (kernels_.find(k) != kernels_.end()) {
- return true;
- }
+ TORCH_INTERNAL_ASSERT(kernels_.find(DispatchKey::Undefined) == kernels_.end());
+ for (auto& kv : kernels_) {
+ if (ks.has(kv.first)) return true;
}
return false;
}
@@ -196,6 +195,9 @@
// In the past we directly call into backends(filled with catchAll) after BackendSelect.
// Now that we first call Autograd backend keys after BackendSelect, we should fill those
// with catchAll as well.
+ // The implementation of (2.1) & (2.3) relies on the invariant that for a given backend,
+ // `computeDispatchTableEntryWithDebug()` will be called for that backend's autograd key after the
+ // backend key. See Note [Refresh Runtime Autograd entries in dispatchTable_]
// (3) Use fallthrough kernel that are registered as fallback.
// (4) Use catchAll kernel if available
// Alias Key Precedence:
@@ -272,7 +274,8 @@
for (auto k : c10::getRuntimeDispatchKeySet(dispatch_key)) {
updateDispatchTableEntry_(dispatcher, k);
}
- // Registering to backend key might affect computed entry at its Autograd backend key due to 2.2.
+ // Note [Refresh Runtime Autograd entries in dispatchTable_]
+ // Registering to backend key might affect computed entry at its Autograd backend key due to (2.1) & (2.3).
DispatchKey autograd_key = getAutogradKeyFromBackend(dispatch_key);
updateDispatchTableEntry_(dispatcher, autograd_key);
}