[functorch] (fix CI) Add batch rule for split.sizes (pytorch/functorch#952)

* add batch rule for split.sizes

* switch to doing split.sizes as a decomposition
diff --git a/functorch/functorch/csrc/BatchRulesDecompositions.cpp b/functorch/functorch/csrc/BatchRulesDecompositions.cpp
index b84940b..aff7acf 100644
--- a/functorch/functorch/csrc/BatchRulesDecompositions.cpp
+++ b/functorch/functorch/csrc/BatchRulesDecompositions.cpp
@@ -181,6 +181,7 @@
   OP_DECOMPOSE(special_multigammaln);
   OP_DECOMPOSE(special_polygamma);
   OP_DECOMPOSE(special_softmax);
+  OP_DECOMPOSE2(split, sizes);
   OP_DECOMPOSE(square);
   OP_DECOMPOSE(numpy_T);
   OP_DECOMPOSE(reshape_as);
diff --git a/functorch/functorch/csrc/LegacyBatchingRegistrations.cpp b/functorch/functorch/csrc/LegacyBatchingRegistrations.cpp
index 0bdbca5..8181174 100644
--- a/functorch/functorch/csrc/LegacyBatchingRegistrations.cpp
+++ b/functorch/functorch/csrc/LegacyBatchingRegistrations.cpp
@@ -270,11 +270,11 @@
 std::vector<Tensor> split_with_sizes_batching_rule(const Tensor& self, IntArrayRef split_sizes, int64_t dim) {
   if (!participatesInCurrentLevel(self)) {
     c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
-    return at::split_with_sizes(self, split_sizes, dim);
+    return split_with_sizes(self, split_sizes, dim);
   }
   auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
   auto dim_physical = self_physical.getPhysicalDim(dim);
-  auto result = at::split_with_sizes(self_physical.tensor(), split_sizes, dim_physical);
+  auto result = split_with_sizes(self_physical.tensor(), split_sizes, dim_physical);
   self_physical.getPhysicalToLogicalMap().applyInplace(result);
   return result;
 }