[functorch] Implement per sample grad rule for cudnn_convolution_backward

This is one special case of the batching rule: grad_output and the input are
batched (with the size-1 per-sample dim that vmap(grad(...)) introduces) while
the weight is not. In that case we can compute per-sample weight gradients
directly with the im2col + einsum approach from opacus instead of falling back
to the slow path.
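
For context, a minimal sketch of the per-sample-gradient computation that
exercises this path (not part of this diff; assumes functorch's
make_functional/grad/vmap API, and the model, loss, and shapes are
illustrative):

    import torch
    import torch.nn.functional as F
    from functorch import make_functional, grad, vmap

    model = torch.nn.Conv2d(3, 8, kernel_size=3).cuda()
    fmodel, params = make_functional(model)

    def compute_loss(params, sample, target):
        # each sample keeps a singleton batch dim, so the physical input
        # reaching cudnn_convolution_backward is 5-D with a size-1 dim
        out = fmodel(params, sample.unsqueeze(0))
        return F.mse_loss(out, target.unsqueeze(0))

    x = torch.randn(64, 3, 32, 32, device='cuda')
    y = torch.randn(64, 8, 30, 30, device='cuda')
    per_sample_grads = vmap(grad(compute_loss), (None, 0, 0))(params, x, y)

With the weight unbatched (in_dims=None for params) and both the input and the
incoming gradient batched, the check in cudnn_convolution_backward_plumbing
below routes to the fast per-sample rule.
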
diff --git a/functorch/functorch/csrc/BatchRulesConv.cpp b/functorch/functorch/csrc/BatchRulesConv.cpp
index ff1e76e..3ab821e 100644
--- a/functorch/functorch/csrc/BatchRulesConv.cpp
+++ b/functorch/functorch/csrc/BatchRulesConv.cpp
@@ -5,6 +5,8 @@
 // LICENSE file in the root directory of this source tree.
 
 #include <functorch/csrc/BatchRulesHelper.h>
+#include <functorch/csrc/PlumbingHelper.h>
+#include <ATen/core/dispatch/Dispatcher.h>
 
 namespace at { namespace functorch {
 
@@ -84,8 +86,83 @@
   return at::convolution(self, weight, bias, stride, padding, dilation, false, out_padding, groups);
 }
 
+// Returns true if the first non-batch dimension of `value` has size 1.
+// Per-sample-gradient pipelines unsqueeze each example to batch size 1, so
+// the physical input here is expected to look like [B, 1, C, H, W].
+bool first_dim_has_size_1(const Tensor& value, int64_t bdim) {
+  if (bdim == 0) {
+    return value.size(1) == 1;
+  }
+  return value.size(0) == 1;
+}
+
+std::tuple<Tensor,int64_t,Tensor,int64_t> cudnn_conv_per_sample_grad_rule(
+    const Tensor& self, optional<int64_t> self_bdim,
+    const Tensor& grad_output, optional<int64_t> grad_output_bdim,
+    const Tensor& weight, optional<int64_t> weight_bdim,
+    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool benchmark,
+    bool deterministic, bool allow_tf32, std::array<bool, 2> output_mask) {
+  TORCH_INTERNAL_ASSERT(self_bdim && grad_output_bdim && !weight_bdim);
+  // TODO: it is unclear whether this works when the first non-batch dim has size > 1
+  TORCH_INTERNAL_ASSERT(first_dim_has_size_1(self, *self_bdim));
+  TORCH_INTERNAL_ASSERT(self.dim() == 5);
+
+  auto bdim_size = self.size(*self_bdim);
+  auto self_ = reshape_dim_into(*self_bdim, 0, self);
+  auto in_channels = self_.size(1);
+  auto grad_output_ = reshape_dim_into(*grad_output_bdim, 0, grad_output);
+
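+  // Gradient w.r.t. the input: the vmap dim folds into the regular batch
+  // dim, so the plain cudnn kernel computes it directly; the vmap dim is
+  // split back out below. Note: output_mask is ignored on this fast path
+  // and both gradients are always computed.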
+  auto grad_self = at::cudnn_convolution_backward_input(
+      self_.sizes(), grad_output_, weight,
+      padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
+  grad_self = reshape_dim_outof(0, bdim_size, grad_self);
+
+  // Copied from https://github.com/pytorch/opacus/blob/master/opacus/grad_sample/conv.py
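+  // A: unfolded input patches of shape [N, in_channels * kh * kw, L], where
+  // N = bdim_size and L is the number of sliding positions.
+  // B: grad_output flattened to [N, out_channels, L]. Contracting over L
+  // below yields per-sample grads of shape [N, out_channels, in_channels * kh * kw].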
+  auto A = at::im2col(self_, {weight.size(2), weight.size(3)}, dilation, padding, stride);
+  auto B = grad_output_.reshape({bdim_size, -1, A.size(-1)});
+  auto grad_sample = at::einsum("noq,npq->nop", {B, A});
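+  // Grouped convolutions: the contraction above paired every output channel
+  // with every input channel, so view the result as [groups, groups] blocks
+  // and take the block diagonal to keep only matching groups.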
+  grad_sample = grad_sample.view({
+      bdim_size, groups, -1, groups, in_channels / groups,
+      weight.size(2) * weight.size(3) });
+  grad_sample = at::einsum("ngrg...->ngr...", {grad_sample});
+  grad_sample = grad_sample.reshape(
+      {bdim_size, weight.size(0), weight.size(1), weight.size(2), weight.size(3)});
+
+  return std::make_tuple(grad_self, 0, grad_sample, 0);
+}
+
+std::tuple<Tensor,Tensor> cudnn_convolution_backward_plumbing(
+    const Tensor& self, const Tensor& grad_output, const Tensor& weight,
+    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
+    int64_t groups, bool benchmark, bool deterministic, bool allow_tf32,
+    std::array<bool, 2> output_mask) {
+  auto maybe_layer = maybeCurrentDynamicLayer();
+  TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
+  int64_t cur_level = maybe_layer->layerId();
+
+  Tensor self_value;
+  optional<int64_t> self_bdim;
+  std::tie(self_value, self_bdim) = unwrapTensorAtLevel(self, cur_level);
+  Tensor grad_output_value;
+  optional<int64_t> grad_output_bdim;
+  std::tie(grad_output_value, grad_output_bdim) = unwrapTensorAtLevel(grad_output, cur_level);
+  Tensor weight_value;
+  optional<int64_t> weight_bdim;
+  std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight, cur_level);
+
+  if (self_bdim.has_value() && self_value.dim() == 5 &&
+      first_dim_has_size_1(self_value, *self_bdim) &&
+      grad_output_bdim.has_value() && !weight_bdim.has_value()) {
+    auto result = cudnn_conv_per_sample_grad_rule(
+        self_value, self_bdim,
+        grad_output_value, grad_output_bdim,
+        weight_value, weight_bdim,
+        padding, stride, dilation, groups,
+        benchmark, deterministic, allow_tf32, output_mask);
+    return std::make_tuple(
+        makeBatched(std::get<0>(result), std::get<1>(result), cur_level),
+        makeBatched(std::get<2>(result), std::get<3>(result), cur_level));
+  }
+
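+  // All other batching configurations fall back to the slow path through
+  // the dispatcher.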
+  static auto op = c10::Dispatcher::singleton()
+    .findSchemaOrThrow("aten::cudnn_convolution_backward", "");
+  return slow_fallback<Tensor,Tensor>(op, { self, grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32, output_mask });
+}
+
 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   VMAP_SUPPORT("conv2d", conv2d_batching_rule);
   m.impl("mkldnn_convolution", mkldnn_convolution_decomp);
+  m.impl("cudnn_convolution_backward", cudnn_convolution_backward_plumbing);
 }
 }}