[functorch] Add batch rules for slice/select/diagonal_backward
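
Adds vmap batch rules for diagonal_backward, select_backward, and
slice_backward, and removes their STOP_DECOMPOSE entries now that each
op is handled directly. All three rules follow the same pattern: move
the batch dimension of the incoming gradient to the front, shift the
logical dim argument(s) past the new leading batch dimension, prepend
the batch size to input_sizes, and call the underlying backward op.
The dim wrapping differs per op: diagonal and select remove a
dimension, so their dims are wrapped against logical_rank + 1 (the
original input's rank), while slice preserves rank and wraps against
logical_rank.

Also removes the xfails in test_vmapvjp_has_batch_rule that now pass
(diagonal, select, narrow, svd, dsplit/hsplit/vsplit, trapezoid,
several fft entries, and others whose VJPs route through these
kernels).

As a usage sketch (not part of this patch), a minimal example that
exercises slice_backward under vmap, using the functorch grad/vmap API:

    import torch
    from functorch import grad, vmap

    x = torch.randn(4, 5)

    def f(xi):
        # xi[1:3] dispatches to aten::slice; its VJP calls
        # aten::slice_backward, which now has a batch rule.
        return xi[1:3].sum()

    # Per-sample gradients push a batch dim through slice_backward.
    per_sample_grads = vmap(grad(f))(x)  # shape (4, 5)
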
diff --git a/functorch/functorch/csrc/BatchRulesStopDecomposition.cpp b/functorch/functorch/csrc/BatchRulesStopDecomposition.cpp
index 6f0f5d8..0d5815b 100644
--- a/functorch/functorch/csrc/BatchRulesStopDecomposition.cpp
+++ b/functorch/functorch/csrc/BatchRulesStopDecomposition.cpp
@@ -704,9 +704,6 @@
   STOP_DECOMPOSE(masked_select_backward);
   STOP_DECOMPOSE(matrix_exp_backward);
   STOP_DECOMPOSE(trace_backward);
-  STOP_DECOMPOSE(slice_backward);
-  STOP_DECOMPOSE(select_backward);
-  STOP_DECOMPOSE(diagonal_backward);
   STOP_DECOMPOSE(cummaxmin_backward);
   STOP_DECOMPOSE(cumprod_backward);
   STOP_DECOMPOSE(diag_backward);
diff --git a/functorch/functorch/csrc/BatchRulesViews.cpp b/functorch/functorch/csrc/BatchRulesViews.cpp
index 342aa45..6eb1412 100644
--- a/functorch/functorch/csrc/BatchRulesViews.cpp
+++ b/functorch/functorch/csrc/BatchRulesViews.cpp
@@ -10,6 +10,7 @@
 #include <functorch/csrc/PlumbingHelper.h>
 #include <functorch/csrc/BatchedFallback.h>
 #include <ATen/core/dispatch/Dispatcher.h>
+#include <c10/util/SmallBuffer.h>
 
 
 namespace at { namespace functorch {
@@ -271,6 +272,51 @@
   return std::make_tuple(result, 0);
 }
 
+std::tuple<Tensor,optional<int64_t>> diagonal_backward_batch_rule(
+    const Tensor& grad_input, optional<int64_t> grad_input_bdim,
+    IntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) {
+  auto logical_rank = rankWithoutBatchDim(grad_input, grad_input_bdim);
+  auto grad_input_ = moveBatchDimToFront(grad_input, grad_input_bdim);
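+  // grad_input has one fewer dim than the original input (diagonal removes
+  // one), so wrap against logical_rank + 1, then shift past the batch dim.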
+  dim1 = maybe_wrap_dim(dim1, logical_rank + 1) + 1;
+  dim2 = maybe_wrap_dim(dim2, logical_rank + 1) + 1;
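+  // Prepend the batch size to input_sizes for the batched result.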
+  c10::SmallBuffer<int64_t, 5> input_sizes_(input_sizes.size() + 1);
+  input_sizes_[0] = grad_input_.size(0);
+  std::copy(input_sizes.begin(), input_sizes.end(), input_sizes_.begin() + 1);
+  auto result = at::diagonal_backward(grad_input_, input_sizes_, offset, dim1, dim2);
+  return std::make_tuple(std::move(result), 0);
+}
+
+std::tuple<Tensor,optional<int64_t>> select_backward_batch_rule(
+    const Tensor& grad_input, optional<int64_t> grad_input_bdim,
+    IntArrayRef input_sizes, int64_t dim, int64_t index) {
+  auto logical_rank = rankWithoutBatchDim(grad_input, grad_input_bdim);
+  auto grad_input_ = moveBatchDimToFront(grad_input, grad_input_bdim);
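+  // select removes a dim; wrap against the input rank, then shift past batch.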
+  dim = maybe_wrap_dim(dim, logical_rank + 1) + 1;
+  c10::SmallBuffer<int64_t, 5> input_sizes_(input_sizes.size() + 1);
+  input_sizes_[0] = grad_input_.size(0);
+  std::copy(input_sizes.begin(), input_sizes.end(), input_sizes_.begin() + 1);
+  auto result = at::select_backward(grad_input_, input_sizes_, dim, index);
+  return std::make_tuple(std::move(result), 0);
+}
+
+std::tuple<Tensor,optional<int64_t>> slice_backward_batch_rule(
+    const Tensor& grad_input, optional<int64_t> grad_input_bdim,
+    IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) {
+  auto logical_rank = rankWithoutBatchDim(grad_input, grad_input_bdim);
+  auto grad_input_ = moveBatchDimToFront(grad_input, grad_input_bdim);
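+  // slice preserves rank, so wrap against logical_rank itself, then shift.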
+  dim = maybe_wrap_dim(dim, logical_rank) + 1;
+  c10::SmallBuffer<int64_t, 5> input_sizes_(input_sizes.size() + 1);
+  input_sizes_[0] = grad_input_.size(0);
+  std::copy(input_sizes.begin(), input_sizes.end(), input_sizes_.begin() + 1);
+  auto result = at::slice_backward(grad_input_, input_sizes_, dim, start, end, step);
+  return std::make_tuple(std::move(result), 0);
+}
+
 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   VMAP_SUPPORT("diag", diag_batch_rule);
   VMAP_SUPPORT("chunk", chunk_batching_rule);
@@ -286,6 +332,9 @@
   VMAP_SUPPORT("select.int", select_batching_rule);
   VMAP_SUPPORT("squeeze", squeeze_batch_rule);
   VMAP_SUPPORT("squeeze.dim", squeeze_dim_batch_rule);
+  VMAP_SUPPORT("diagonal_backward", diagonal_backward_batch_rule);
+  VMAP_SUPPORT("select_backward", select_backward_batch_rule);
+  VMAP_SUPPORT("slice_backward", slice_backward_batch_rule);
 }
 
 }}
diff --git a/functorch/test/test_ops.py b/functorch/test/test_ops.py
index c305fcb..95c43da 100644
--- a/functorch/test/test_ops.py
+++ b/functorch/test/test_ops.py
@@ -379,28 +379,16 @@
         xfail('cummax'),
         xfail('cummin'),
         xfail('cumprod'),
-        xfail('cumulative_trapezoid'),
         xfail('diag'),
         xfail('diag_embed'),
-        xfail('diagonal'),
-        xfail('diff'),
-        xfail('dsplit'),
         xfail('eig'),
-        xfail('einsum'),
-        xfail('fft.fftn'),
-        xfail('fft.hfft'),
-        xfail('fft.ifftn'),
         xfail('fft.ihfft'),
-        xfail('fft.irfft'),
-        xfail('fft.irfftn'),
         xfail('fft.rfft'),
         xfail('fft.rfftn'),
         xfail('fill_'),
         xfail('float_power'),
         xfail('fmax'),
         xfail('fmin'),
-        xfail('gradient'),
-        xfail('hsplit'),
         xfail('index_add'),
         xfail('index_copy'),
         xfail('index_fill'),
@@ -409,7 +397,6 @@
         xfail('kthvalue'),
         xfail('linalg.cholesky'),
         xfail('linalg.cholesky_ex'),
-        xfail('linalg.cond'),
         xfail('linalg.det'),
         xfail('linalg.eig'),
         xfail('linalg.eigh'),
@@ -422,7 +409,6 @@
         xfail('linalg.pinv', 'hermitian'),
         xfail('linalg.slogdet'),
         xfail('linalg.solve'),
-        xfail('linalg.svd'),
         xfail('linalg.tensorinv'),
         xfail('linalg.vector_norm'),
         xfail('log_softmax'),
@@ -445,7 +431,6 @@
         xfail('msort'),
         xfail('nanmedian'),
         xfail('nanquantile'),
-        xfail('narrow'),
         xfail('nn.functional.adaptive_avg_pool2d'),
         xfail('nn.functional.avg_pool2d'),
         xfail('nn.functional.conv_transpose2d'),
@@ -470,22 +455,17 @@
         xfail('roll'),
         xfail('rot90'),
         xfail('scatter_add'),
-        xfail('select'),
         xfail('solve'),
         xfail('sort'),
-        xfail('svd'),
         xfail('symeig'),
         xfail('take'),
         xfail('tensor_split'),
         xfail('to_sparse'),
         xfail('topk'),
         xfail('trace'),
-        xfail('trapezoid'),
-        xfail('trapz'),
         xfail('unfold'),
         xfail('vdot'),
         xfail('view_as_complex'),
-        xfail('vsplit'),
     })
     def test_vmapvjp_has_batch_rule(self, device, dtype, op):
         # These are too annoying to put into the list above