[functorch] [port] `expand` to new api (pytorch/functorch#161)

* [port] expand to new api

* update code

* update code

* fix incorrect merge

* retrigger CI
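
For context, the new expand_batch_rule below keeps the same strategy as the removed expand_batching_rule: move the batch dimension to the front, insert singleton dimensions after it so the physical rank matches the requested size, then call expand. A minimal standalone sketch of that view-then-expand step in plain ATen (assuming a libtorch build; the shapes mirror the Tensor[B0, 3] -> [2, 3] example from the code comments, with B0 = 4):

    #include <ATen/ATen.h>
    #include <iostream>

    int main() {
      // Physical tensor with the batch dim B0 = 4 at the front; logical shape is [3].
      auto x = at::randn({4, 3});
      // Expanding [4, 3] directly to [4, 2, 3] would fail (the existing size 3
      // doesn't line up with 2), so first view as [4, 1, 3]...
      auto viewed = x.view({4, 1, 3});
      // ...then expand to the batched target size [B0, 2, 3].
      auto expanded = viewed.expand({4, 2, 3});
      std::cout << expanded.sizes() << std::endl;  // [4, 2, 3]
      return 0;
    }

The batch rule generalizes this by deriving the viewed shape and the batched target size from self_bdim and the incoming size argument.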
diff --git a/functorch/functorch/csrc/BatchRulesViews.cpp b/functorch/functorch/csrc/BatchRulesViews.cpp
index ef383d5..1768c53 100644
--- a/functorch/functorch/csrc/BatchRulesViews.cpp
+++ b/functorch/functorch/csrc/BatchRulesViews.cpp
@@ -354,6 +354,40 @@
   return std::make_tuple(std::move(result), 0);
 }
 
+std::tuple<Tensor, optional<int64_t>> expand_batch_rule(
+    const Tensor &self, optional<int64_t> self_bdim, IntArrayRef size, bool implicit)
+{
+  auto self_dim = self.dim();
+  TORCH_CHECK(static_cast<uint64_t>(self_dim - 1) <= size.size(),
+              "expand: the number of sizes provided (", size.size(), ") ",
+              "must be greater or equal to the number of dimensions in the tensor (", static_cast<uint64_t>(self_dim - 1), ")");
+
+  auto self_ = moveBatchDimToFront(self, self_bdim);
+  auto self_sizes = self_.sizes();
+  auto batch_size = self_sizes[0];
+
+  c10::SmallBuffer<int64_t, 5> size_(size.size() + 1);
+  size_[0] = batch_size;
+  std::copy(size.cbegin(), size.cend(), size_.begin() + 1);
+
+  // Here, we are expanding a (logical) tensor to an equal or larger number of
+  // dimensions. We have to be careful because we can't call expand directly
+  // due to the presence of batch dimensions.
+  //
+  // As an example, let B0 be a batch dimension and consider expand(Tensor[B0, 3], [2, 3]).
+  // The result should be a tensor of size [B0, 2, 3].
+  // A physical view of size [B0, 3] can't directly be expanded to size [B0, 2, 3]
+  // so the strategy here is to view it first as a tensor of size [B0, 1, 3] and
+  // then expand.
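+  //
+  // Concretely, for that example: size_ = [B0, 2, 3], extra_dims = 1, and
+  // view_shape = [B0, 1, 3], so view(view_shape).expand(size_) yields [B0, 2, 3].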
+  auto extra_dims = size.size() - (self_dim - 1);
+  VmapDimVector view_shape(size_.size(), /*init_value*/1);
+  view_shape[0] = batch_size;
+  std::copy(self_sizes.cbegin() + 1, self_sizes.cend(),
+            view_shape.begin() + 1 + extra_dims);
+
+  return std::make_tuple(self_.view(view_shape).expand(size_, implicit), 0);
+}
+
 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   VMAP_SUPPORT("diag", diag_batch_rule);
   VMAP_SUPPORT("chunk", chunk_batching_rule);
@@ -375,6 +409,7 @@
   VMAP_SUPPORT("diagonal_backward", diagonal_backward_batch_rule);
   VMAP_SUPPORT("select_backward", select_backward_batch_rule);
   VMAP_SUPPORT("slice_backward", slice_backward_batch_rule);
+  VMAP_SUPPORT("expand", expand_batch_rule);
 }
 
 }}
diff --git a/functorch/functorch/csrc/BatchingRegistrations.cpp b/functorch/functorch/csrc/BatchingRegistrations.cpp
index bf402a0..fdf0a74 100644
--- a/functorch/functorch/csrc/BatchingRegistrations.cpp
+++ b/functorch/functorch/csrc/BatchingRegistrations.cpp
@@ -132,49 +132,6 @@
   return true;
 }
 
-Tensor expand_batching_rule(const Tensor& self, IntArrayRef size, bool implicit) {
-  if (!participatesInCurrentLevel(self)) {
-    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
-    return self.expand(size, implicit);
-  }
-
-  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
-  auto size_physical = self_physical.getPhysicalShape(size);
-  auto self_physical_dim = self_physical.tensor().dim();
-
-  TORCH_CHECK((uint64_t)self_physical_dim <= size_physical.size(),
-       "expand: the number of sizes provided (", /*logical*/size.size(), ") ",
-       "must be greater or equal to the number of dimensions in the tensor (",
-       /*logical dim*/self.dim(), ")");
-
-  if ((uint64_t)self_physical_dim == size_physical.size()) {
-    auto result = self_physical.tensor().expand(size_physical, implicit);
-    return self_physical.getPhysicalToLogicalMap().apply(result);
-  }
-
-  TORCH_INTERNAL_ASSERT((uint64_t)self_physical_dim < size_physical.size());
-  // Here, we know we are expanding a (logical) tensor to a larger number
-  // of dimensions. We have to be careful because we can't call expand directly
-  // due to the presence of batch dimensions.
-  //
-  // As an example, let B0 be a batch dimension and consider expand(Tensor[B0, 3], [2, 3]).
-  // The result should be a tensor of size [B0, 2, 3].
-  // A physical view of size [B0, 3] can't directly be expanded to size [B0, 2, 3]
-  // so the strategy here is to view it first as a tensor of size [B0, 1, 3] and
-  // then expand.
-  auto self_physical_size = self_physical.tensor().sizes();
-  auto extra_dims = size_physical.size() - self_physical_dim;
-  VmapDimVector view_shape(size_physical.size(), 1);
-  std::copy(self_physical_size.begin(),
-            self_physical_size.begin() + self_physical.numBatchDims(),
-            view_shape.begin());
-  std::copy(self_physical_size.begin() + self_physical.numBatchDims(),
-            self_physical_size.end(),
-            view_shape.begin() + self_physical.numBatchDims() + extra_dims);
-  auto result = self_physical.tensor().view(view_shape).expand(size_physical, implicit);
-  return self_physical.getPhysicalToLogicalMap().apply(result);
-}
-
 std::vector<Tensor> chunk_batching_rule(const Tensor& self, int64_t chunks, int64_t dim) {
   if (!participatesInCurrentLevel(self)) {
     c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
@@ -1001,7 +958,6 @@
   // m.impl("chunk", chunk_batching_rule);
   m.impl("tensor_split.sections", tensor_split_sections_batching_rule);
   m.impl("tensor_split.indices", tensor_split_indices_batching_rule);
-  m.impl("expand", expand_batching_rule);
   m.impl("movedim.intlist", movedim_batching_rule);
   m.impl("movedim.int", static_cast<Tensor(*)(const Tensor&,int64_t,int64_t)>(native::movedim)); // composite wrt autograd
   // NB: static_cast because there's another variant of narrow. However, we don't