[functorch] Implement per sample grad rule for cudnn_convolution_backward
This handles one special case of the batching rule: the per-sample-gradient pattern produced by vmap(grad(...)), where self and grad_output carry a batch dim, weight does not, and each example is itself a conv batch of size 1 (so the batched input is 5-D). The input gradient comes from running the regular cudnn kernel on the flattened batch; the weight gradient is computed per sample with the im2col + einsum trick from opacus. All other cases go through the slow fallback.
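
For context, a sketch of the kind of user code this fast path is aimed at (names and shapes are illustrative, not part of this patch):

    import torch
    import torch.nn.functional as F
    from functorch import grad, vmap

    weight = torch.randn(8, 3, 3, 3, device='cuda')
    samples = torch.randn(16, 3, 24, 24, device='cuda')
    targets = torch.randn(16, 8, 22, 22, device='cuda')

    def loss(weight, sample, target):
        # Each vmapped sample is [C, H, W]; unsqueeze so conv2d sees [1, C, H, W].
        out = F.conv2d(sample.unsqueeze(0), weight)
        return (out - target.unsqueeze(0)).pow(2).sum()

    # Under vmap, the conv input is a 5-D BatchedTensor [16, 1, 3, 24, 24] whose
    # per-example batch dim has size 1 -- the case handled below.
    per_sample_grads = vmap(grad(loss), in_dims=(None, 0, 0))(weight, samples, targets)
    # per_sample_grads: [16, 8, 3, 3, 3], one weight gradient per sample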
diff --git a/functorch/functorch/csrc/BatchRulesConv.cpp b/functorch/functorch/csrc/BatchRulesConv.cpp
index ff1e76e..3ab821e 100644
--- a/functorch/functorch/csrc/BatchRulesConv.cpp
+++ b/functorch/functorch/csrc/BatchRulesConv.cpp
@@ -5,6 +5,8 @@
// LICENSE file in the root directory of this source tree.
#include <functorch/csrc/BatchRulesHelper.h>
+#include <functorch/csrc/PlumbingHelper.h>
+#include <ATen/core/dispatch/Dispatcher.h>
namespace at { namespace functorch {
@@ -84,8 +86,83 @@
return at::convolution(self, weight, bias, stride, padding, dilation, false, out_padding, groups);
}
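+
+// True if the first non-vmap dim of `value` has size 1, i.e. each example in
+// the vmap batch is itself a conv batch of a single sample.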
+bool first_dim_has_size_1(const Tensor& value, int64_t bdim) {
+  if (bdim == 0) {
+    return value.size(1) == 1;
+  }
+  return value.size(0) == 1;
+}
+
+std::tuple<Tensor,int64_t,Tensor,int64_t> cudnn_conv_per_sample_grad_rule(
+    const Tensor& self, optional<int64_t> self_bdim,
+    const Tensor& grad_output, optional<int64_t> grad_output_bdim,
+    const Tensor& weight, optional<int64_t> weight_bdim,
+    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool benchmark,
+    bool deterministic, bool allow_tf32, std::array<bool, 2> output_mask) {
+  TORCH_INTERNAL_ASSERT(self_bdim && grad_output_bdim && !weight_bdim);
+  // TODO: No clue if this works if the first non-batch dim isn't size 1
+  TORCH_INTERNAL_ASSERT(first_dim_has_size_1(self, *self_bdim));
+  TORCH_INTERNAL_ASSERT(self.dim() == 5);
+
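+  // Fold the vmap dim into the conv batch dim: self goes from
+  // [B, 1, C, H, W] to [B, C, H, W], so cudnn sees a regular NCHW input.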
+  auto bdim_size = self.size(*self_bdim);
+  auto self_ = reshape_dim_into(*self_bdim, 0, self);
+  auto in_channels = self_.size(1);
+  auto grad_output_ = reshape_dim_into(*grad_output_bdim, 0, grad_output);
+
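+  // grad_input batches trivially: run the regular cudnn kernel on the
+  // flattened batch, then split the vmap dim back out of dim 0.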
+  auto grad_self = at::cudnn_convolution_backward_input(
+      self_.sizes(), grad_output_, weight,
+      padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
+  grad_self = reshape_dim_outof(0, bdim_size, grad_self);
+
+ // Copied from https://github.com/pytorch/opacus/blob/master/opacus/grad_sample/conv.py
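+  // A: input unfolded into patches, [N, in_channels * kh * kw, L] (L = number
+  // of kernel positions). B: grad_output as [N, out_channels, L]. Contracting
+  // over positions q gives each sample's weight gradient; the "ngrg...->ngr..."
+  // einsum below then keeps the block diagonal over groups.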
+  auto A = at::im2col(self_, {weight.size(2), weight.size(3)}, dilation, padding, stride);
+  auto B = grad_output_.reshape({bdim_size, -1, A.size(-1)});
+  auto grad_sample = at::einsum("noq,npq->nop", {B, A});
+  grad_sample = grad_sample.view({
+      bdim_size, groups, -1, groups, in_channels / groups,
+      weight.size(2) * weight.size(3) });
+  grad_sample = at::einsum("ngrg...->ngr...", {grad_sample});
+  grad_sample = grad_sample.reshape(
+      {bdim_size, weight.size(0), weight.size(1), weight.size(2), weight.size(3)});
+
+  return std::make_tuple(grad_self, 0, grad_sample, 0);
+}
+
+std::tuple<Tensor,Tensor> cudnn_convolution_backward_plumbing(
+    const Tensor & self, const Tensor & grad_output, const Tensor & weight,
+    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
+    int64_t groups, bool benchmark, bool deterministic, bool allow_tf32,
+    std::array<bool, 2> output_mask) {
+  auto maybe_layer = maybeCurrentDynamicLayer();
+  TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
+  int64_t cur_level = maybe_layer->layerId();
+
+  Tensor self_value;
+  optional<int64_t> self_bdim;
+  std::tie(self_value, self_bdim) = unwrapTensorAtLevel(self, cur_level);
+  Tensor grad_output_value;
+  optional<int64_t> grad_output_bdim;
+  std::tie(grad_output_value, grad_output_bdim) = unwrapTensorAtLevel(grad_output, cur_level);
+  Tensor weight_value;
+  optional<int64_t> weight_bdim;
+  std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight, cur_level);
+
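+  // Per-sample-grad fast path: self is a 5-D batched tensor whose per-example
+  // batch dim has size 1, grad_output is batched, and weight is unbatched.
+  // This is the pattern vmap(grad(...)) produces for conv per-sample grads.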
+  if (self_bdim.has_value() && self_value.dim() == 5 &&
+      first_dim_has_size_1(self_value, *self_bdim) &&
+      grad_output_bdim.has_value() && !weight_bdim.has_value()) {
+    auto result = cudnn_conv_per_sample_grad_rule(
+        self_value, self_bdim,
+        grad_output_value, grad_output_bdim,
+        weight_value, weight_bdim,
+        padding, stride, dilation, groups,
+        benchmark, deterministic, allow_tf32, output_mask);
+    return std::make_tuple(
+        makeBatched(std::get<0>(result), std::get<1>(result), cur_level),
+        makeBatched(std::get<2>(result), std::get<3>(result), cur_level));
+  }
+
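+  // Any other batching configuration goes through the generic slow fallback
+  // for ops without a batching rule.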
+  static auto op = c10::Dispatcher::singleton()
+      .findSchemaOrThrow("aten::cudnn_convolution_backward", "");
+  return slow_fallback<Tensor,Tensor>(
+      op, {self, grad_output, weight, padding, stride, dilation, groups,
+           benchmark, deterministic, allow_tf32, output_mask});
+}
+
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
VMAP_SUPPORT("conv2d", conv2d_batching_rule);
m.impl("mkldnn_convolution", mkldnn_convolution_decomp);
+ m.impl("cudnn_convolution_backward", cudnn_convolution_backward_plumbing);
}
}}
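
For reference, the opacus trick the weight-gradient path mirrors, written out in plain PyTorch. This is a minimal sketch assuming groups == 1, stride 1, and no padding (shapes are illustrative); the "ngrg...->ngr..." einsum in the patch additionally extracts the block diagonal when groups > 1:

    import torch
    import torch.nn.functional as F

    N, C_in, H, W = 16, 3, 24, 24       # illustrative shapes
    C_out, kh, kw = 8, 3, 3
    x = torch.randn(N, C_in, H, W)
    gO = torch.randn(N, C_out, H - kh + 1, W - kw + 1)  # grad_output

    # unfold == im2col: A is [N, C_in*kh*kw, L], one column per kernel position.
    A = F.unfold(x, (kh, kw))
    B = gO.reshape(N, -1, A.size(-1))   # [N, C_out, L]

    # Contract over positions q: each sample's weight grad is grad_output
    # correlated with the input patches it saw.
    grad_sample = torch.einsum('noq,npq->nop', B, A).view(N, C_out, C_in, kh, kw)

    # Sanity check against autograd on sample 0.
    w = torch.randn(C_out, C_in, kh, kw, requires_grad=True)
    F.conv2d(x[:1], w).backward(gO[:1])
    assert torch.allclose(w.grad, grad_sample[0], atol=1e-3)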