[functorch] (fix CI) Add batch rule for split.sizes (pytorch/functorch#952)
* Add a batch rule for split.sizes
* Switch to implementing split.sizes as a decomposition instead of a dedicated batch rule
diff --git a/functorch/functorch/csrc/BatchRulesDecompositions.cpp b/functorch/functorch/csrc/BatchRulesDecompositions.cpp
index b84940b..aff7acf 100644
--- a/functorch/functorch/csrc/BatchRulesDecompositions.cpp
+++ b/functorch/functorch/csrc/BatchRulesDecompositions.cpp
@@ -181,6 +181,7 @@
OP_DECOMPOSE(special_multigammaln);
OP_DECOMPOSE(special_polygamma);
OP_DECOMPOSE(special_softmax);
+ OP_DECOMPOSE2(split, sizes);
OP_DECOMPOSE(square);
OP_DECOMPOSE(numpy_T);
OP_DECOMPOSE(reshape_as);
diff --git a/functorch/functorch/csrc/LegacyBatchingRegistrations.cpp b/functorch/functorch/csrc/LegacyBatchingRegistrations.cpp
index 0bdbca5..8181174 100644
--- a/functorch/functorch/csrc/LegacyBatchingRegistrations.cpp
+++ b/functorch/functorch/csrc/LegacyBatchingRegistrations.cpp
@@ -270,11 +270,11 @@
std::vector<Tensor> split_with_sizes_batching_rule(const Tensor& self, IntArrayRef split_sizes, int64_t dim) {
if (!participatesInCurrentLevel(self)) {
c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
- return at::split_with_sizes(self, split_sizes, dim);
+ return split_with_sizes(self, split_sizes, dim);
}
auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
auto dim_physical = self_physical.getPhysicalDim(dim);
- auto result = at::split_with_sizes(self_physical.tensor(), split_sizes, dim_physical);
+ auto result = split_with_sizes(self_physical.tensor(), split_sizes, dim_physical);
self_physical.getPhysicalToLogicalMap().applyInplace(result);
return result;
}