[functorch] Some view backward batch rules

Adds batch rules for diagonal_backward, select_backward, and
slice_backward in BatchRulesViews.cpp and registers them via
VMAP_SUPPORT. Their STOP_DECOMPOSE entries are removed since these ops
no longer need to be decomposed under vmap. Each rule follows the same
pattern: move the batch dim to the front of grad_input, prepend the
batch size to input_sizes, shift the dim argument(s) past the new
batch dim, and dispatch back to the underlying backward op. Note the
dim wrapping: select and diagonal reduce the input's rank by one, so
their rules wrap dim against logical_rank + 1, while slice preserves
rank and wraps against logical_rank. The OpInfo entries that now pass
are removed from the test_vmapvjp_has_batch_rule xfail list in
test_ops.py.
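
For example (a rough sketch; the shapes and helper name below are
illustrative, not part of this patch), vmapping over the vjp of
Tensor.diagonal now routes the gradient through the new
diagonal_backward batch rule instead of the slow fallback path:

    import torch
    from functorch import vmap, vjp

    x = torch.randn(8, 5, 5)  # dim 0 is the vmapped dim

    def per_example_grad(xi):
        # the vjp of diagonal() calls at::diagonal_backward
        out, vjp_fn = vjp(lambda t: t.diagonal(0, -2, -1), xi)
        (grad,) = vjp_fn(torch.ones_like(out))
        return grad

    grads = vmap(per_example_grad)(x)
    assert grads.shape == x.shape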
diff --git a/functorch/functorch/csrc/BatchRulesStopDecomposition.cpp b/functorch/functorch/csrc/BatchRulesStopDecomposition.cpp
index 6f0f5d8..0d5815b 100644
--- a/functorch/functorch/csrc/BatchRulesStopDecomposition.cpp
+++ b/functorch/functorch/csrc/BatchRulesStopDecomposition.cpp
@@ -704,9 +704,6 @@
STOP_DECOMPOSE(masked_select_backward);
STOP_DECOMPOSE(matrix_exp_backward);
STOP_DECOMPOSE(trace_backward);
- STOP_DECOMPOSE(slice_backward);
- STOP_DECOMPOSE(select_backward);
- STOP_DECOMPOSE(diagonal_backward);
STOP_DECOMPOSE(cummaxmin_backward);
STOP_DECOMPOSE(cumprod_backward);
STOP_DECOMPOSE(diag_backward);
diff --git a/functorch/functorch/csrc/BatchRulesViews.cpp b/functorch/functorch/csrc/BatchRulesViews.cpp
index 342aa45..6eb1412 100644
--- a/functorch/functorch/csrc/BatchRulesViews.cpp
+++ b/functorch/functorch/csrc/BatchRulesViews.cpp
@@ -10,6 +10,7 @@
#include <functorch/csrc/PlumbingHelper.h>
#include <functorch/csrc/BatchedFallback.h>
#include <ATen/core/dispatch/Dispatcher.h>
+#include <c10/util/SmallBuffer.h>

namespace at { namespace functorch {
@@ -271,6 +272,46 @@
return std::make_tuple(result, 0);
}

+std::tuple<Tensor,optional<int64_t>> diagonal_backward_batch_rule(
+    const Tensor& grad_input, optional<int64_t> grad_input_bdim,
+    IntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) {
+  auto logical_rank = rankWithoutBatchDim(grad_input, grad_input_bdim);
+  auto grad_input_ = moveBatchDimToFront(grad_input, grad_input_bdim);
+  dim1 = maybe_wrap_dim(dim1, logical_rank + 1) + 1;
+  dim2 = maybe_wrap_dim(dim2, logical_rank + 1) + 1;
+  c10::SmallBuffer<int64_t, 5> input_sizes_(input_sizes.size() + 1);
+  input_sizes_[0] = grad_input_.size(0);
+  std::copy(input_sizes.begin(), input_sizes.end(), input_sizes_.begin() + 1);
+  auto result = at::diagonal_backward(grad_input_, input_sizes_, offset, dim1, dim2);
+  return std::make_tuple(std::move(result), 0);
+}
+
+std::tuple<Tensor,optional<int64_t>> select_backward_batch_rule(
+    const Tensor& grad_input, optional<int64_t> grad_input_bdim,
+    IntArrayRef input_sizes, int64_t dim, int64_t index) {
+  auto logical_rank = rankWithoutBatchDim(grad_input, grad_input_bdim);
+  auto grad_input_ = moveBatchDimToFront(grad_input, grad_input_bdim);
+  dim = maybe_wrap_dim(dim, logical_rank + 1) + 1;
+  c10::SmallBuffer<int64_t, 5> input_sizes_(input_sizes.size() + 1);
+  input_sizes_[0] = grad_input_.size(0);
+  std::copy(input_sizes.begin(), input_sizes.end(), input_sizes_.begin() + 1);
+  auto result = at::select_backward(grad_input_, input_sizes_, dim, index);
+  return std::make_tuple(std::move(result), 0);
+}
+
+std::tuple<Tensor,optional<int64_t>> slice_backward_batch_rule(
+    const Tensor& grad_input, optional<int64_t> grad_input_bdim,
+    IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) {
+  auto logical_rank = rankWithoutBatchDim(grad_input, grad_input_bdim);
+  auto grad_input_ = moveBatchDimToFront(grad_input, grad_input_bdim);
+  dim = maybe_wrap_dim(dim, logical_rank) + 1;
+  c10::SmallBuffer<int64_t, 5> input_sizes_(input_sizes.size() + 1);
+  input_sizes_[0] = grad_input_.size(0);
+  std::copy(input_sizes.begin(), input_sizes.end(), input_sizes_.begin() + 1);
+  auto result = at::slice_backward(grad_input_, input_sizes_, dim, start, end, step);
+  return std::make_tuple(std::move(result), 0);
+}
+
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
VMAP_SUPPORT("diag", diag_batch_rule);
VMAP_SUPPORT("chunk", chunk_batching_rule);
@@ -286,6 +327,9 @@
VMAP_SUPPORT("select.int", select_batching_rule);
VMAP_SUPPORT("squeeze", squeeze_batch_rule);
VMAP_SUPPORT("squeeze.dim", squeeze_dim_batch_rule);
+ VMAP_SUPPORT("diagonal_backward", diagonal_backward_batch_rule);
+ VMAP_SUPPORT("select_backward", select_backward_batch_rule);
+ VMAP_SUPPORT("slice_backward", slice_backward_batch_rule);
}
}}
diff --git a/functorch/test/test_ops.py b/functorch/test/test_ops.py
index c305fcb..95c43da 100644
--- a/functorch/test/test_ops.py
+++ b/functorch/test/test_ops.py
@@ -379,28 +379,16 @@
xfail('cummax'),
xfail('cummin'),
xfail('cumprod'),
- xfail('cumulative_trapezoid'),
xfail('diag'),
xfail('diag_embed'),
- xfail('diagonal'),
- xfail('diff'),
- xfail('dsplit'),
xfail('eig'),
- xfail('einsum'),
- xfail('fft.fftn'),
- xfail('fft.hfft'),
- xfail('fft.ifftn'),
xfail('fft.ihfft'),
- xfail('fft.irfft'),
- xfail('fft.irfftn'),
xfail('fft.rfft'),
xfail('fft.rfftn'),
xfail('fill_'),
xfail('float_power'),
xfail('fmax'),
xfail('fmin'),
- xfail('gradient'),
- xfail('hsplit'),
xfail('index_add'),
xfail('index_copy'),
xfail('index_fill'),
@@ -409,7 +397,6 @@
xfail('kthvalue'),
xfail('linalg.cholesky'),
xfail('linalg.cholesky_ex'),
- xfail('linalg.cond'),
xfail('linalg.det'),
xfail('linalg.eig'),
xfail('linalg.eigh'),
@@ -422,7 +409,6 @@
xfail('linalg.pinv', 'hermitian'),
xfail('linalg.slogdet'),
xfail('linalg.solve'),
- xfail('linalg.svd'),
xfail('linalg.tensorinv'),
xfail('linalg.vector_norm'),
xfail('log_softmax'),
@@ -445,7 +431,6 @@
xfail('msort'),
xfail('nanmedian'),
xfail('nanquantile'),
- xfail('narrow'),
xfail('nn.functional.adaptive_avg_pool2d'),
xfail('nn.functional.avg_pool2d'),
xfail('nn.functional.conv_transpose2d'),
@@ -470,22 +455,17 @@
xfail('roll'),
xfail('rot90'),
xfail('scatter_add'),
- xfail('select'),
xfail('solve'),
xfail('sort'),
- xfail('svd'),
xfail('symeig'),
xfail('take'),
xfail('tensor_split'),
xfail('to_sparse'),
xfail('topk'),
xfail('trace'),
- xfail('trapezoid'),
- xfail('trapz'),
xfail('unfold'),
xfail('vdot'),
xfail('view_as_complex'),
- xfail('vsplit'),
})
def test_vmapvjp_has_batch_rule(self, device, dtype, op):
# These are too annoying to put into the list above
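
The trimmed xfail list above is exercised by
test_vmapvjp_has_batch_rule; a scoped sanity run (assuming pytest as
the runner) would be: pytest functorch/test/test_ops.py -k
vmapvjp_has_batch_rule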