[functorch] some view backward batch rules
diff --git a/functorch/functorch/csrc/BatchRulesStopDecomposition.cpp b/functorch/functorch/csrc/BatchRulesStopDecomposition.cpp
index 6f0f5d8..0d5815b 100644
--- a/functorch/functorch/csrc/BatchRulesStopDecomposition.cpp
+++ b/functorch/functorch/csrc/BatchRulesStopDecomposition.cpp
@@ -704,9 +704,6 @@
STOP_DECOMPOSE(masked_select_backward);
STOP_DECOMPOSE(matrix_exp_backward);
STOP_DECOMPOSE(trace_backward);
- STOP_DECOMPOSE(slice_backward);
- STOP_DECOMPOSE(select_backward);
- STOP_DECOMPOSE(diagonal_backward);
STOP_DECOMPOSE(cummaxmin_backward);
STOP_DECOMPOSE(cumprod_backward);
STOP_DECOMPOSE(diag_backward);
diff --git a/functorch/functorch/csrc/BatchRulesViews.cpp b/functorch/functorch/csrc/BatchRulesViews.cpp
index 342aa45..137ab34 100644
--- a/functorch/functorch/csrc/BatchRulesViews.cpp
+++ b/functorch/functorch/csrc/BatchRulesViews.cpp
@@ -10,6 +10,7 @@
#include <functorch/csrc/PlumbingHelper.h>
#include <functorch/csrc/BatchedFallback.h>
#include <ATen/core/dispatch/Dispatcher.h>
+#include <c10/util/SmallBuffer.h>
namespace at { namespace functorch {
@@ -271,6 +272,46 @@
return std::make_tuple(result, 0);
}
+std::tuple<Tensor,optional<int64_t>> diagonal_backward_batch_rule(
+ const Tensor& grad_input, optional<int64_t> grad_input_bdim,
+ IntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) {
+ auto logical_rank = rankWithoutBatchDim(grad_input, grad_input_bdim);
+ auto grad_input_ = moveBatchDimToFront(grad_input, grad_input_bdim);
+ dim1 = maybe_wrap_dim(dim1, logical_rank + 1) + 1;
+ dim2 = maybe_wrap_dim(dim2, logical_rank + 1) + 1;
+ c10::SmallBuffer<int64_t, 5> input_sizes_(input_sizes.size() + 1);
+ input_sizes_[0] = grad_input_.size(0);
+ std::copy(input_sizes.begin(), input_sizes.end(), input_sizes_.begin() + 1);
+ auto result = at::diagonal_backward(grad_input_, input_sizes_, offset, dim1, dim2);
+ return std::make_tuple(std::move(result), 0);
+}
+
+std::tuple<Tensor,optional<int64_t>> select_backward_batch_rule(
+ const Tensor& grad_input, optional<int64_t> grad_input_bdim,
+ IntArrayRef input_sizes, int64_t dim, int64_t index) {
+ auto logical_rank = rankWithoutBatchDim(grad_input, grad_input_bdim);
+ auto grad_input_ = moveBatchDimToFront(grad_input, grad_input_bdim);
+ dim = maybe_wrap_dim(dim, logical_rank + 1) + 1;
+ c10::SmallBuffer<int64_t, 5> input_sizes_(input_sizes.size() + 1);
+ input_sizes_[0] = grad_input_.size(0);
+ std::copy(input_sizes.begin(), input_sizes.end(), input_sizes_.begin() + 1);
+ auto result = at::select_backward(grad_input_, input_sizes_, dim, index);
+ return std::make_tuple(std::move(result), 0);
+}
+
+std::tuple<Tensor,optional<int64_t>> slice_backward_batch_rule(
+ const Tensor& grad_input, optional<int64_t> grad_input_bdim,
+ IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) {
+ auto logical_rank = rankWithoutBatchDim(grad_input, grad_input_bdim);
+ auto grad_input_ = moveBatchDimToFront(grad_input, grad_input_bdim);
+ dim = maybe_wrap_dim(dim, logical_rank + 1) + 1;
+ c10::SmallBuffer<int64_t, 5> input_sizes_(input_sizes.size() + 1);
+ input_sizes_[0] = grad_input_.size(0);
+ std::copy(input_sizes.begin(), input_sizes.end(), input_sizes_.begin() + 1);
+ auto result = at::slice_backward(grad_input_, input_sizes_, dim, start, end, step);
+ return std::make_tuple(std::move(result), 0);
+}
+
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
VMAP_SUPPORT("diag", diag_batch_rule);
VMAP_SUPPORT("chunk", chunk_batching_rule);
@@ -286,6 +327,9 @@
VMAP_SUPPORT("select.int", select_batching_rule);
VMAP_SUPPORT("squeeze", squeeze_batch_rule);
VMAP_SUPPORT("squeeze.dim", squeeze_dim_batch_rule);
+ VMAP_SUPPORT("diagonal_backward", diagonal_backward_batch_rule);
+ VMAP_SUPPORT("select_backward", select_backward_batch_rule);
+ VMAP_SUPPORT("slice_backward", slice_backward_batch_rule);
}
}}
diff --git a/functorch/test/test_ops.py b/functorch/test/test_ops.py
index c305fcb..e180572 100644
--- a/functorch/test/test_ops.py
+++ b/functorch/test/test_ops.py
@@ -382,7 +382,6 @@
xfail('cumulative_trapezoid'),
xfail('diag'),
xfail('diag_embed'),
- xfail('diagonal'),
xfail('diff'),
xfail('dsplit'),
xfail('eig'),
@@ -470,7 +469,6 @@
xfail('roll'),
xfail('rot90'),
xfail('scatter_add'),
- xfail('select'),
xfail('solve'),
xfail('sort'),
xfail('svd'),