[functorch] Selectively enable dispatch on kBatchedKey (pytorch/functorch#63)

This PR makes it so that dispatch on kBatchedKey can only happen if there are
tensors batched at the current level. Otherwise, kBatchedKey remains excluded
(even if there are BatchedTensors; they are batched at other levels!).
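
For a concrete sense of the user-visible effect, here is a minimal Python sketch
(it mirrors the test added below; assumes functorch's vmap is importable as shown):

```python
import torch
from functorch import vmap

def f(x, y):
    # Inside the inner vmap, only x is batched at the inner level; y is
    # batched at the outer level. torch.dot therefore sees no tensors batched
    # at the current (inner) level, so its batch rule is not invoked there
    # and the plain dot kernel runs instead.
    res = torch.dot(y, torch.ones(2))
    return x + res

x = torch.randn(7, 5)
y = torch.randn(3, 2)
out = vmap(vmap(f, in_dims=(0, None)), in_dims=(None, 0))(x, y)
```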

To find tensors batched at the current level, we check:
- all Tensor arguments
- all TensorList arguments (we peek inside each list)
- all Tensor?[] arguments (we peek inside each list)

These cases should be sufficient (see the sketch below for why the list cases matter).
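
For illustration, a couple of common ops whose schemas take these list types
(the op choices here are just examples, not tied to this PR):

```python
import torch

xs = [torch.randn(3), torch.randn(3)]
torch.stack(xs)          # stack takes a TensorList; batched inputs can hide inside the list

t = torch.randn(4, 4)
idx = torch.tensor([0, 2])
t[idx]                   # advanced indexing lowers to aten::index, whose indices are Tensor?[]
```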

Dispatch for kVmapModeKey is not affected.

Test Plan:
- ran all tests
- removed the special case in dot_batch_rule and added a test that covers it
diff --git a/functorch/functorch/csrc/BatchRulesLinearAlgebra.cpp b/functorch/functorch/csrc/BatchRulesLinearAlgebra.cpp
index a8419f5..268d87a 100644
--- a/functorch/functorch/csrc/BatchRulesLinearAlgebra.cpp
+++ b/functorch/functorch/csrc/BatchRulesLinearAlgebra.cpp
@@ -40,8 +40,6 @@
   auto B_ = moveBatchDimToFront(B, B_bdim);
   if (A_bdim && B_bdim) {
     return std::make_tuple(at::matmul(A_.unsqueeze(-2), B_.unsqueeze(-1)).squeeze(-1).squeeze(-1), 0);
-  } else if (!A_bdim && !B_bdim) {
-    return std::make_tuple(at::dot(A_, B_), nullopt);
   } else {
     return std::make_tuple(at::matmul(A_, B_.t()), 0);
   }
diff --git a/functorch/functorch/csrc/DynamicLayer.cpp b/functorch/functorch/csrc/DynamicLayer.cpp
index 88c0dfd..0444f65 100644
--- a/functorch/functorch/csrc/DynamicLayer.cpp
+++ b/functorch/functorch/csrc/DynamicLayer.cpp
@@ -230,6 +230,45 @@
   }
 }
 
+static bool allTensors(
+    ArrayRef<IValue> args,
+    std::function<bool(const Tensor&)> pred) {
+  for (const auto& ivalue : args) {
+    // Tensor?[] translates to a c10::List<IValue> so we need to peek inside List
+    if (ivalue.isList()) {
+      for (const auto& elt : ivalue.toListRef()) {
+        if (elt.isTensor() && !pred(elt.toTensor())) {
+            return false;
+        }
+      }
+      continue;
+    }
+    if (ivalue.isTensorList()) {
+      for (const auto& elt : ivalue.toTensorList()) {
+        if (!pred(elt)) {
+          return false;
+        }
+      }
+      continue;
+    }
+    TORCH_INTERNAL_ASSERT(!ivalue.isGenericDict(), "No operators can accept GenericDict");
+    if (!ivalue.isTensor()) {
+      continue;
+    }
+    if (!pred(ivalue.toTensor())) {
+      return false;
+    }
+  }
+  return true;
+}
+
+static bool anyTensors(
+    ArrayRef<IValue> args,
+    std::function<bool(const Tensor&)> pred) {
+  // De Morgan's law: any(pred) == !all(!pred)
+  return !allTensors(args, [&](const Tensor& self) { return !pred(self); });
+}
+
 constexpr DispatchKeySet all_dynlayer_keyset = DispatchKeySet({
   kDynamicLayerFrontModeKey,
   kDynamicLayerBackModeKey,
@@ -252,6 +291,19 @@
       });
 }
 
+static bool batchedAtCurrentLevel(const Tensor& tensor) {
+  auto& dynamicLayerStack = dynamicLayerStackAccessor();
+  auto layer = dynamicLayerStack.back();
+  auto level = layer.layerId();
+
+  auto* batched = maybeGetBatchedImpl(tensor);
+  if (!batched) {
+    return false;
+  }
+  auto batched_at_level = batched->bdims().back().level();
+  return batched_at_level == level;
+}
+
 void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
   auto& dynamicLayerStack = dynamicLayerStackAccessor();
 #ifdef HAS_TORCH_SHOW_DISPATCH_TRACE
@@ -282,10 +334,13 @@
   if (layer.key() == DispatchKey::Autograd) {
     exclude = exclude - autograd_dispatch_keyset;
     exclude = exclude.remove(DispatchKey::ADInplaceOrView);
-  // } else if (layer.key() == DispatchKey::Batched) {
-  //   exclude = exclude.remove(DispatchKey::Batched);
   } else if (layer.key() == kBatchedKey) {
-    exclude = exclude.remove(kBatchedKey);
+    // Only enable dispatch on kBatchedKey if there are tensors batched
+    // at the current level.
+    const auto args = torch::jit::last(stack, op.schema().arguments().size());
+    if (anyTensors(args, batchedAtCurrentLevel)) {
+      exclude = exclude.remove(kBatchedKey);
+    }
     include = include.add(kVmapModeKey);
   } else {
     TORCH_INTERNAL_ASSERT(false);
diff --git a/functorch/test/test_vmap.py b/functorch/test/test_vmap.py
index 0840454..7810877 100644
--- a/functorch/test/test_vmap.py
+++ b/functorch/test/test_vmap.py
@@ -932,6 +932,18 @@
         y = reshape_dim_outof(-1, 6, x)
         self.assertEqual(y, x.reshape(12, 12, 6, 2))
 
+    def test_batch_rule_does_not_need_to_handle_no_batched_input(self):
+        def f(x, y):
+            res = torch.dot(y, torch.ones(2))
+            return x + res
+
+        x = torch.randn(7, 5)
+        y = torch.randn(3, 2)
+        out = vmap(vmap(f, in_dims=(0, None)), in_dims=(None, 0))(x, y)
+        expected = torch.mv(y, torch.ones(2)).view(3, 1, 1) + x
+        self.assertEqual(out, expected)
+
+
 def slice_inputs(inputs, bdims, i):
     result = []
     for inp, bdim in zip(inputs, bdims):