extract dispatch keys from optional Tensors (unboxed) (#58296)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/58296
Test Plan: Imported from OSS
Reviewed By: Chillee
Differential Revision: D28436822
Pulled By: bhosmer
fbshipit-source-id: 8031c9a3c121483dd0e5ed7b8b165952477108e4
diff --git a/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h b/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h
index 7011686..9d3cc67 100644
--- a/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h
+++ b/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h
@@ -55,6 +55,11 @@
void operator()(const at::Tensor& x) {
ts = ts | x.key_set();
}
+ void operator()(c10::optional<at::Tensor> x) {
+ if (x.has_value()) {
+ ts = ts | x->key_set();
+ }
+ }
void operator()(at::ArrayRef<at::Tensor> xs) {
for (const auto& x : xs) {
ts = ts | x.key_set();