[functorch] [port] `expand` to new api (pytorch/functorch#161)
* [port] expand to new api
* update code
* update code
* fix incorrect merge
* retrigger CI
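
For context, a minimal sketch of the user-facing behavior the ported rule has to reproduce (assuming the functorch Python frontend; `B0` is the vmapped dimension — this snippet is illustrative, not part of the PR):

```python
import torch
from functorch import vmap

B0 = 4
x = torch.randn(B0, 3)                      # B0 is the vmapped (batch) dimension
out = vmap(lambda t: t.expand(2, 3))(x)     # each logical [3] slice expands to [2, 3]
print(out.shape)                            # torch.Size([4, 2, 3])
```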
diff --git a/functorch/functorch/csrc/BatchRulesViews.cpp b/functorch/functorch/csrc/BatchRulesViews.cpp
index ef383d5..1768c53 100644
--- a/functorch/functorch/csrc/BatchRulesViews.cpp
+++ b/functorch/functorch/csrc/BatchRulesViews.cpp
@@ -354,6 +354,40 @@
return std::make_tuple(std::move(result), 0);
}
+std::tuple<Tensor, optional<int64_t>> expand_batch_rule(
+ const Tensor &self, optional<int64_t> self_bdim, IntArrayRef size, bool implicit)
+{
+ auto self_dim = self.dim();
+ TORCH_CHECK(static_cast<uint64_t>(self_dim - 1) <= size.size(),
+ "expand: the number of sizes provided (", size.size(), ") ",
+ "must be greater or equal to the number of dimensions in the tensor (", static_cast<uint64_t>(self_dim - 1), ")");
+
+ auto self_ = moveBatchDimToFront(self, self_bdim);
+ auto self_sizes = self_.sizes();
+ auto batch_size = self_sizes[0];
+
+ c10::SmallBuffer<int64_t, 5> size_(size.size() + 1);
+ size_[0] = batch_size;
+ std::copy(size.cbegin(), size.cend(), size_.begin() + 1);
+
+ // Here, we know we are expanding a (logical) tensor to a larger number
+ // of dimensions. We have to be careful because we can't call expand directly
+ // due to the presence of batch dimensions.
+ //
+ // As an example, let B0 be a batch dimension and consider expand(Tensor[B0, 3], [2, 3]).
+ // The result should be a tensor of size [B0, 2, 3].
+ // A physical view of size [B0, 3] can't directly be expanded to size [B0, 2, 3]
+ // so the strategy here is to view it first as a tensor of size [B0, 1, 3] and
+ // then expand.
+ auto extra_dims = size.size() - (self_dim - 1);
+ VmapDimVector view_shape(size_.size(), /*init_value*/1);
+ view_shape[0] = batch_size;
+ std::copy(self_sizes.cbegin() + 1, self_sizes.cend(),
+ view_shape.begin() + 1 + extra_dims);
+
+ return std::make_tuple(self_.view(view_shape).expand(size_, implicit), 0);
+}
+
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
VMAP_SUPPORT("diag", diag_batch_rule);
VMAP_SUPPORT("chunk", chunk_batching_rule);
@@ -375,6 +409,7 @@
VMAP_SUPPORT("diagonal_backward", diagonal_backward_batch_rule);
VMAP_SUPPORT("select_backward", select_backward_batch_rule);
VMAP_SUPPORT("slice_backward", slice_backward_batch_rule);
+ VMAP_SUPPORT("expand", expand_batch_rule);
}
}}
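
To make the view-then-expand strategy described in the comment inside `expand_batch_rule` concrete, here is a rough Python mirror of its shape bookkeeping (variable names follow the C++ above; the concrete sizes are just the `[B0, 3] -> [B0, 2, 3]` example from the comment, not anything the PR itself adds):

```python
import torch

B0 = 4
self_ = torch.randn(B0, 3)          # physical tensor, batch dim already moved to the front
size = [2, 3]                       # requested (logical) expand sizes

self_sizes = list(self_.shape)      # [4, 3]
batch_size = self_sizes[0]

size_ = [batch_size] + size                          # [4, 2, 3]: target physical size
extra_dims = len(size) - (len(self_sizes) - 1)       # 1 new leading logical dim
view_shape = [batch_size] + [1] * extra_dims + self_sizes[1:]   # [4, 1, 3]

out = self_.view(view_shape).expand(*size_)          # view to [4, 1, 3], then expand
print(out.shape)                                     # torch.Size([4, 2, 3])
```

The view step inserts size-1 dimensions right after the batch dimension so that the subsequent `expand` only ever broadcasts size-1 dims, which mirrors why the C++ rule cannot call `expand` directly on the physical tensor.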
diff --git a/functorch/functorch/csrc/BatchingRegistrations.cpp b/functorch/functorch/csrc/BatchingRegistrations.cpp
index bf402a0..fdf0a74 100644
--- a/functorch/functorch/csrc/BatchingRegistrations.cpp
+++ b/functorch/functorch/csrc/BatchingRegistrations.cpp
@@ -132,49 +132,6 @@
return true;
}
-Tensor expand_batching_rule(const Tensor& self, IntArrayRef size, bool implicit) {
- if (!participatesInCurrentLevel(self)) {
- c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
- return self.expand(size, implicit);
- }
-
- auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
- auto size_physical = self_physical.getPhysicalShape(size);
- auto self_physical_dim = self_physical.tensor().dim();
-
- TORCH_CHECK((uint64_t)self_physical_dim <= size_physical.size(),
- "expand: the number of sizes provided (", /*logical*/size.size(), ") ",
- "must be greater or equal to the number of dimensions in the tensor (",
- /*logical dim*/self.dim(), ")");
-
- if ((uint64_t)self_physical_dim == size_physical.size()) {
- auto result = self_physical.tensor().expand(size_physical, implicit);
- return self_physical.getPhysicalToLogicalMap().apply(result);
- }
-
- TORCH_INTERNAL_ASSERT((uint64_t)self_physical_dim < size_physical.size());
- // Here, we know we are expanding a (logical) tensor to a larger number
- // of dimensions. We have to be careful because we can't call expand directly
- // due to the presence of batch dimensions.
- //
- // As an example, let B0 be a batch dimension and consider expand(Tensor[B0, 3], [2, 3]).
- // The result should be a tensor of size [B0, 2, 3].
- // A physical view of size [B0, 3] can't directly be expanded to size [B0, 2, 3]
- // so the strategy here is to view it first as a tensor of size [B0, 1, 3] and
- // then expand.
- auto self_physical_size = self_physical.tensor().sizes();
- auto extra_dims = size_physical.size() - self_physical_dim;
- VmapDimVector view_shape(size_physical.size(), 1);
- std::copy(self_physical_size.begin(),
- self_physical_size.begin() + self_physical.numBatchDims(),
- view_shape.begin());
- std::copy(self_physical_size.begin() + self_physical.numBatchDims(),
- self_physical_size.end(),
- view_shape.begin() + self_physical.numBatchDims() + extra_dims);
- auto result = self_physical.tensor().view(view_shape).expand(size_physical, implicit);
- return self_physical.getPhysicalToLogicalMap().apply(result);
-}
-
std::vector<Tensor> chunk_batching_rule(const Tensor& self, int64_t chunks, int64_t dim) {
if (!participatesInCurrentLevel(self)) {
c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
@@ -1001,7 +958,6 @@
// m.impl("chunk", chunk_batching_rule);
m.impl("tensor_split.sections", tensor_split_sections_batching_rule);
m.impl("tensor_split.indices", tensor_split_indices_batching_rule);
- m.impl("expand", expand_batching_rule);
m.impl("movedim.intlist", movedim_batching_rule);
m.impl("movedim.int", static_cast<Tensor(*)(const Tensor&,int64_t,int64_t)>(native::movedim)); // composite wrt autograd
// NB: static_cast because there's another variant of narrow. However, we don't