[functorch] Exclude List[Optional[Tensor]] from the batched fallback
diff --git a/functorch/functorch/csrc/BatchedFallback.cpp b/functorch/functorch/csrc/BatchedFallback.cpp
index 73b3675..8d3d8c6 100644
--- a/functorch/functorch/csrc/BatchedFallback.cpp
+++ b/functorch/functorch/csrc/BatchedFallback.cpp
@@ -57,7 +57,13 @@
return std::any_of(
schema.arguments().begin(),
schema.arguments().end(),
- [] (const Argument& arg) { return arg.type()->isSubtypeOf(ListType::ofTensors()); });
+ [] (const Argument& arg) {
+ static auto ListOfOptionalTensors = ListType::create(OptionalType::ofTensor());
+ if (arg.type()->isSubtypeOf(ListType::ofTensors())) {
+ return true;
+ }
+ return arg.type()->isSubtypeOf(ListOfOptionalTensors);
+ });
}
// Returns if an operator is in-place. An operator is inplace if: