blob: 3cf00f33def5570a319ad7f5c0660fb9eeaa37aa [file] [log] [blame]
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#include <ATen/functorch/BatchRulesHelper.h>
#include <ATen/functorch/PlumbingHelper.h>
#include <ATen/core/dispatch/Dispatcher.h>
namespace at::functorch {
// vmap batching rule for aten::convolution.
//
// convolution_batch_rule translated from jax with modifications:
// https://github.com/google/jax/blob/master/jax/_src/lax/lax.py#L3143
// PyTorch's convolution is different from JAX's conv_general_dilated:
// we do not support batch_group_count (which is needed for convolution backwards).
// Instead, there's a convolution_backward op that needs a batching rule.
//
// Strategy: the vmapped dim (B) of the input (lhs) and/or weight (rhs) is
// folded into an existing dim so that one at::convolution_symint call handles
// every example, and B is then split back out of the output:
//   - only lhs batched: B is folded into the input's dim 0 (N).
//   - only rhs batched: B is folded into the weight's output-channel dim; when
//     groups != 1 the weight/output are additionally reshuffled so that every
//     group sees the matching slice of every example's weights.
//   - both batched: B is folded into the channel dims and `groups` is
//     multiplied by B, so each example becomes its own set of groups.
//   - neither batched (only the bias can be): a plain convolution.
// A bias that cannot be passed into the convolution directly is added to the
// output afterwards (see `separate_bias`).
//
// Returns the output tensor and the position of its batch dim, or
// std::nullopt when the result carries no batch dim.
static std::tuple<Tensor, std::optional<int64_t>>
convolution_batch_rule(const Tensor& lhs, std::optional<int64_t> lhs_bdim, const Tensor& rhs, std::optional<int64_t> rhs_bdim, const std::optional<Tensor>& bias, std::optional<int64_t> bias_bdim, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups) {
  // Dimension-order specs for the input, weight and output. Each starts as
  // [0, 1, ..., stride.size() + 1]: dim 0, dim 1, then one entry per spatial dim.
  DimVector lhs_spec(stride.size() + 2);
  std::iota(lhs_spec.begin(), lhs_spec.end(), 0);
  DimVector rhs_spec = lhs_spec;
  DimVector out_spec = lhs_spec;
  // Transposed-conv weights are laid out IOHW instead of OIHW, so the
  // weight's out-channel dim is 1 and its in-channel dim is 0.
  if (transposed) {
    rhs_spec[0] = 1;
    rhs_spec[1] = 0;
  }
  // If we have a batched bias or weight, we need to perform the computation separately.
  // (Every batched-weight path below puts B into the output's channel dim, so
  // a per-channel bias can no longer be fused into the convolution call.)
  std::optional<Tensor> unbatched_bias;
  bool separate_bias = false;
  if ((rhs_bdim && bias && bias->defined()) || bias_bdim) {
    TORCH_INTERNAL_ASSERT(bias.has_value());
    TORCH_INTERNAL_ASSERT(bias->defined());
    unbatched_bias = std::nullopt;
    separate_bias = true;
  } else {
    unbatched_bias = bias;
    separate_bias = false;
  }
  std::tuple<Tensor, std::optional<int64_t>> result;
  if (lhs_bdim && !rhs_bdim) {
    // Only the input is batched: fold B into the input's dim 0 ((BN)...),
    // convolve, then split B back out of the output's dim 0.
    auto new_x = reshape_dim_into(*lhs_bdim, lhs_spec[0], lhs);
    auto out = at::convolution_symint(new_x, rhs, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups);
    out = reshape_dim_outof_symint(out_spec[0], lhs.sizes()[*lhs_bdim], out);
    result = std::make_tuple(out, out_spec[0]);
  } else if (!lhs_bdim && rhs_bdim) {
    if (groups == 1) {
      // Only the weight is batched (no groups): fold B into the weight's
      // out-channel dim (rhs_spec[0]), so the output's channel dim becomes
      // (BO); split it back out at output dim 1.
      auto new_w = reshape_dim_into(*rhs_bdim, rhs_spec[0], rhs);
      auto out = at::convolution_symint(lhs, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups);
      out = reshape_dim_outof_symint(out_spec[1], rhs.size(*rhs_bdim), out);
      result = std::make_tuple(out, out_spec[1]);
    } else {
      if (transposed) {
        // conv_transpose with groups is normally NIHW, IOHW -> N(GO)HW
        // With RHS batched, we do the following:
        // NIHW, BIOHW -> NIHW, I(BO)HW -> N(GBO)HW -> BN(GO)HW
        // NB: the following isn't written using rhs_spec
        // (PyTorch convs have a fixed dimension order)
        // BIOHW -> I(BO)HW
        auto new_w = reshape_dim_into(*rhs_bdim, 1, rhs);
        // NIHW, I(BO)HW -> N(GBO)HW
        auto out = at::convolution_symint(lhs, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups);
        // N(GBO)HW -> NG(BO)HW
        out = reshape_dim_outof_symint(1, groups, out);
        // NG(BO)HW -> NGBOHW
        out = reshape_dim_outof_symint(2, rhs.size(*rhs_bdim), out);
        // NGBOHW -> NB(GO)HW
        out = reshape_dim_into(1, 2, out);
        result = std::make_tuple(out, 1);
      } else {
        // conv with groups is normally N(GI)HW, (GO)IHW -> N(GO)HW
        // With RHS batched, we do the following:
        // N(GI)HW, B(GO)IHW -> N(GI)HW, (GBO)IHW -> N(GBO)HW -> BN(GO)HW
        // NB: the following isn't written using rhs_spec
        // (PyTorch convs have a fixed dimension order)
        // B(GO)IHW -> BGOIHW
        // (The (GO) dim sits at index 0 unless B itself is at index 0, in
        // which case (GO) is at index 1.)
        auto new_w = reshape_dim_outof_symint(0 + (*rhs_bdim == 0), groups, rhs);
        // BGOIHW -> G(BO)IHW
        // (Splitting (GO) inserted one dim before B whenever B was not at
        // index 0, hence the +1 adjustment of B's position.)
        new_w = reshape_dim_into(*rhs_bdim + (*rhs_bdim > 0), 1, new_w);
        // G(BO)IHW -> (GBO)IHW
        new_w = reshape_dim_into(0, 0, new_w);
        // N(GI)HW, (GBO)IHW -> N(GBO)HW
        auto out = at::convolution_symint(lhs, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups);
        // N(GBO)HW -> NG(BO)HW
        out = reshape_dim_outof_symint(1, groups, out);
        // NG(BO)HW -> NGBOHW
        out = reshape_dim_outof_symint(2, rhs.size(*rhs_bdim), out);
        // NGBOHW -> NB(GO)HW
        out = reshape_dim_into(1, 2, out);
        result = std::make_tuple(out, 1);
      }
    }
  } else if (lhs_bdim && rhs_bdim) {
    // Both batched: fold B into the input's channel dim (N(BC)...) and into
    // weight dim 0, then run with groups * B so that each example only
    // convolves with its own weights. rhs_spec[dim_with_groups] is 0 in both
    // the regular and the transposed case: weight dim 0 is the dim that
    // carries the groups (out-channels for OIHW, in-channels for IOHW).
    auto new_x = reshape_dim_into(*lhs_bdim, lhs_spec[1], lhs);
    groups *= lhs.sizes()[*lhs_bdim];
    auto dim_with_groups = transposed ? 1 : 0;
    auto new_w = reshape_dim_into(*rhs_bdim, rhs_spec[dim_with_groups], rhs);
    auto out = at::convolution_symint(new_x, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups);
    out = reshape_dim_outof_symint(out_spec[1], lhs.sizes()[*lhs_bdim], out);
    result = std::make_tuple(out, out_spec[1]);
  } else {
    // Neither input nor weight is batched (only the bias can be here).
    result = std::make_tuple(at::convolution_symint(lhs, rhs, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups), std::nullopt);
  }
  if (separate_bias) {
    // Add the bias by broadcasting: move both batch dims (if any) to the
    // front, append one size-1 dim per spatial dim to the bias so that it
    // lines up with the output's channel dim, pad it to the output's logical
    // rank (NOTE(review): maybePadToLogicalRank presumably inserts size-1
    // logical dims after the batch dim — confirm), and add.
    // The summed result therefore has its batch dim at 0.
    auto A = std::get<0>(result);
    auto A_batch_dim = std::get<1>(result);
    auto B = *bias;
    auto B_batch_dim = bias_bdim;
    A = moveBatchDimToFront(A, A_batch_dim);
    B = moveBatchDimToFront(B, B_batch_dim);
    for (size_t i = 0; i < out_spec.size() - 2; i++) {
      B = B.unsqueeze(-1);
    }
    B = maybePadToLogicalRank(B, B_batch_dim, rankWithoutBatchDim(A, A_batch_dim));
    return std::make_tuple(at::add(A, B), 0);
  } else {
    return result;
  }
}
// vmap decomposition of aten::_convolution: it simply forwards to
// at::convolution, which is dispatched with FuncTorchBatched still active and
// therefore reaches the convolution batching rule registered in this file.
//
// The trailing flags (benchmark, deterministic, cudnn_enabled, allow_tf32)
// appear to be backend-selection/tuning hints and are intentionally
// discarded: if the user called this in the normal way, then they should be
// fine. NOTE(review): allow_tf32 can change numerics on CUDA — confirm that
// dropping it is acceptable for every caller.
static Tensor _convolution_decomp(
    const Tensor& input_r, const Tensor& weight_r, const std::optional<Tensor>& bias_r_opt,
    IntArrayRef stride_, IntArrayRef padding_, IntArrayRef dilation_,
    bool transposed_, IntArrayRef output_padding_, int64_t groups_,
    [[maybe_unused]] bool benchmark, [[maybe_unused]] bool deterministic,
    [[maybe_unused]] bool cudnn_enabled, [[maybe_unused]] bool allow_tf32) {
  return at::convolution(
      input_r,
      weight_r,
      bias_r_opt,
      stride_,
      padding_,
      dilation_,
      transposed_,
      output_padding_,
      groups_);
}
// Gradient of the convolution bias: grad_output_ summed over every dim except
// dim 1 (the channel dim), i.e. over dims {0, 2, 3, ..., dim() - 1}.
// Returns an undefined Tensor when output_mask[2] (the bias-grad slot) is false.
// Assumes grad_output_ has at least 2 dims, as a convolution output does
// (NOTE(review): this is not checked here).
static Tensor compute_grad_bias(
    const Tensor& grad_output_, std::array<bool, 3> output_mask) {
  if (output_mask[2]) {
    DimVector summed_dims;
    summed_dims.push_back(0);
    for (int64_t d = 2; d < grad_output_.dim(); ++d) {
      summed_dims.push_back(d);
    }
    return grad_output_.sum(summed_dims);
  }
  return Tensor();
}
// Builds a placeholder for a convolution_backward argument whose values are
// never read, only its shape and tensor options.
//
// The batch dim `tensor_bdim` (if present) is removed by selecting index 0.
// The result keeps that unbatched shape except at `dim`, whose size is
// multiplied by `batch_size`; that is, it reshapes the batch_size into dim.
// It is a 0-dim new_empty tensor expanded to that shape, so its contents are
// uninitialized and only a single element is actually allocated.
static Tensor make_dummy(
    const Tensor& tensor, std::optional<int64_t> tensor_bdim,
    int64_t dim, int64_t batch_size) {
  Tensor base = tensor;
  if (tensor_bdim.has_value()) {
    base = tensor.select(*tensor_bdim, 0);
  }
  const int64_t folded_size = batch_size * base.size(dim);
  base = base.slice(dim, 0, 1);
  DimVector dummy_shape(base.sizes().begin(), base.sizes().end());
  dummy_shape[dim] = folded_size;
  return base.new_empty({}).expand(dummy_shape);
}
// Batch rule for the grad_input output of convolution_backward
// (mask {true, false, false}).
//
// Notation: B = vmapped dim, N = batch, O = output channels, I = input
// channels, G = groups; spatial dims are omitted. The vmapped dim is folded
// into an existing dim (and into `groups` when both grad_output and weight are
// batched) so that a single at::convolution_backward_symint call handles every
// example, and B is then split back out of grad_input.
//
// Only the shape/options of `input` are used: it is always replaced by a
// placeholder from make_dummy, sized for the folded layout.
//
// Returns grad_input and the position of its batch dim (std::nullopt when
// neither grad_output nor weight is batched).
static std::tuple<Tensor, std::optional<int64_t>>
convolution_backward_input_batch_rule(
    const Tensor& grad_output, std::optional<int64_t> grad_output_bdim,
    const Tensor& input, std::optional<int64_t> input_bdim,
    const Tensor& weight, std::optional<int64_t> weight_bdim,
    c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed,
    c10::SymIntArrayRef output_padding, const c10::SymInt& groups) {
  // Ask convolution_backward for grad_input only.
  const std::array<bool, 3> mask = {true, false, false};
  if (grad_output_bdim && weight_bdim) {
    // Both batched: every example becomes its own group(s) (groups * batch_size).
    // regular: BNO, BOI -> N(BO), (BO)I -> N(BI)
    // transposed: BNO, BIO -> N(BO), (BI)O -> N(BI)
    const auto batch_size = weight.size(*weight_bdim);
    const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 1, grad_output);
    const auto weight_ = reshape_dim_into(*weight_bdim, 0, weight);
    auto dummy_input = make_dummy(input, input_bdim, 1, batch_size);
    const auto result = at::convolution_backward_symint(
        grad_output_, dummy_input, weight_, std::nullopt, stride, padding,
        dilation, transposed, output_padding, groups * batch_size, mask);
    const auto grad_input = reshape_dim_outof(1, batch_size, std::get<0>(result));
    return std::make_tuple(grad_input, 1);
  } else if (grad_output_bdim && !weight_bdim) {
    // Only grad_output batched: fold B into the batch dim N.
    // BNO, OI -> (BN)O, OI -> (BN)I
    // transposed is the same.
    const auto batch_size = grad_output.size(*grad_output_bdim);
    const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 0, grad_output);
    auto dummy_input = make_dummy(input, input_bdim, 0, batch_size);
    const auto result = at::convolution_backward_symint(
        grad_output_, dummy_input, weight, std::nullopt, stride, padding,
        dilation, transposed, output_padding, groups, mask);
    const auto grad_input = reshape_dim_outof(0, batch_size, std::get<0>(result));
    return std::make_tuple(grad_input, 0);
  } else if (!grad_output_bdim && weight_bdim) {
    // Only weight batched: fold B into the weight's in-channel dim, so
    // grad_input's channel dim holds B as well.
    const auto batch_size = weight.size(*weight_bdim);
    if (groups == 1) {
      // regular: NO, BOI -> NO, O(BI) -> N(BI)
      // transposed: NO, BIO -> NO, (BI)O -> N(BI)
      const auto in_ch_dim = transposed ? 0 : 1;
      const auto weight_ = reshape_dim_into(*weight_bdim, in_ch_dim, weight);
      auto dummy_input = make_dummy(input, input_bdim, 1, batch_size);
      const auto result = at::convolution_backward_symint(
          grad_output, dummy_input, weight_, std::nullopt, stride, padding,
          dilation, transposed, output_padding, groups, mask);
      const auto grad_input = reshape_dim_outof(1, batch_size, std::get<0>(result));
      return std::make_tuple(grad_input, 1);
    }
    // groups > 1: B has to land inside each group's in-channel slice, so the
    // convolution produces grad_input laid out as N(GBI); it is reordered to
    // NB(GI) at the end.
    Tensor grad_input;
    if (!transposed) {
      // N(GO), B(GO)I -> N(GO), (GO)(BI) -> N(GBI)
      const auto weight_ = reshape_dim_into(*weight_bdim, 1, weight);
      auto dummy_input = make_dummy(input, input_bdim, 1, batch_size);
      const auto result = at::convolution_backward_symint(
          grad_output, dummy_input, weight_, std::nullopt, stride, padding,
          dilation, transposed, output_padding, groups, mask);
      grad_input = std::get<0>(result); // N(GBI)
    } else {
      // N(GO), B(GI)O -> N(GO), (GBI)O -> N(GBI)
      auto weight_ = moveBatchDimToFront(weight, weight_bdim); // B(GI)O
      weight_ = reshape_dim_outof_symint(1, groups, weight_); // BGIO
      weight_ = weight_.transpose(0, 1); // GBIO
      weight_ = weight_.flatten(0, 2); // (GBI)O
      const auto dummy_input = make_dummy(input, input_bdim, 1, batch_size);
      const auto result = at::convolution_backward_symint(
          grad_output, dummy_input, weight_, std::nullopt, stride, padding,
          dilation, transposed, output_padding, groups, mask);
      grad_input = std::get<0>(result); // N(GBI)
    }
    // N(GBI) -> NG(BI) -> NGBI -> NBGI -> NB(GI)
    grad_input = reshape_dim_outof_symint(1, groups, grad_input);
    grad_input = reshape_dim_outof_symint(2, batch_size, grad_input);
    grad_input = grad_input.transpose(1, 2);
    grad_input = reshape_dim_into(2, 2, grad_input);
    return std::make_tuple(grad_input, 1);
  } else {
    // Neither grad_output nor weight is batched, so (by the assert below)
    // only input is; a placeholder of its unbatched shape is used, and
    // grad_input is returned without a batch dim.
    TORCH_INTERNAL_ASSERT(input_bdim);
    const auto dummy_input = make_dummy(input, input_bdim, 0, 1);
    const auto result = at::convolution_backward_symint(
        grad_output, dummy_input, weight, std::nullopt, stride, padding,
        dilation, transposed, output_padding, groups, mask);
    return std::make_tuple(std::get<0>(result), std::nullopt);
  }
}
// Batch rule for the grad_weight output of convolution_backward
// (mask {false, true, false}).
//
// Notation: B = vmapped dim, N = batch, O = output channels, I = input
// channels, G = groups; spatial dims are omitted. Regular weights are OI and
// transposed weights are IO. The vmapped dim of grad_output and/or input is
// folded into a channel dim (and into `groups` when both are batched) so that
// a single at::convolution_backward_symint call handles every example, and B
// is then split back out of grad_weight.
//
// Only the shape/options of `weight` are used: it is always replaced by a
// placeholder from make_dummy, sized for the folded layout.
//
// Returns grad_weight and the position of its batch dim (std::nullopt when
// neither grad_output nor input is batched).
static std::tuple<Tensor, std::optional<int64_t>>
convolution_backward_weight_batch_rule(
    const Tensor& grad_output, std::optional<int64_t> grad_output_bdim,
    const Tensor& input, std::optional<int64_t> input_bdim,
    const Tensor& weight, std::optional<int64_t> weight_bdim,
    c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed,
    c10::SymIntArrayRef output_padding, const c10::SymInt& groups) {
  // Ask convolution_backward for grad_weight only.
  const std::array<bool, 3> mask = {false, true, false};
  if (grad_output_bdim && input_bdim) {
    // Both batched: every example becomes its own group(s) (groups * batch_size).
    // BNO, BNI -> N(BO), N(BI) -> (BO)I (regular) (BI)O (transposed)
    const auto batch_size = input.size(*input_bdim);
    const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 1, grad_output);
    const auto input_ = reshape_dim_into(*input_bdim, 1, input);
    const auto dummy_weight = make_dummy(weight, weight_bdim, 0, batch_size);
    const auto result = at::convolution_backward_symint(
        grad_output_, input_, dummy_weight, std::nullopt, stride, padding,
        dilation, transposed, output_padding, groups * batch_size, mask);
    auto grad_weight = std::get<1>(result);
    grad_weight = reshape_dim_outof_symint(0, batch_size, grad_weight);
    return std::make_tuple(grad_weight, 0);
  } else if (grad_output_bdim && !input_bdim) {
    // Only grad_output batched: fold B into the weight's out-channel dim.
    const auto batch_size = grad_output.size(*grad_output_bdim);
    if (groups == 1) {
      // regular: BNO, NI -> N(BO), NI -> (BO)I
      // transposed: BNO, NI -> N(BO), NI -> I(BO)
      const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 1, grad_output);
      const auto out_ch_dim = transposed ? 1 : 0;
      const auto dummy_weight = make_dummy(weight, weight_bdim, out_ch_dim, batch_size);
      const auto result = at::convolution_backward_symint(
          grad_output_, input, dummy_weight, std::nullopt, stride, padding,
          dilation, transposed, output_padding, groups, mask);
      auto grad_weight = std::get<1>(result);
      grad_weight = reshape_dim_outof_symint(out_ch_dim, batch_size, grad_weight);
      return std::make_tuple(grad_weight, out_ch_dim);
    } else {
      // groups > 1: reorder grad_output so that B sits inside each group's
      // out-channel slice before convolving.
      auto grad_output_ = moveBatchDimToFront(grad_output, grad_output_bdim); // BN(GO)
      grad_output_ = reshape_dim_outof_symint(2, groups, grad_output_); // BNGO
      grad_output_ = grad_output_.movedim(0, 2); // NGBO
      grad_output_ = grad_output_.flatten(1, 3); // N(GBO)
      if (!transposed) {
        // BN(GO), N(GI) -> N(GBO), N(GI) -> (GBO)I
        const auto dummy_weight = make_dummy(weight, weight_bdim, 0, batch_size);
        const auto result = at::convolution_backward_symint(
            grad_output_, input, dummy_weight, std::nullopt, stride, padding,
            dilation, transposed, output_padding, groups, mask);
        auto grad_weight = std::get<1>(result);
        grad_weight = grad_weight.unflatten_symint(0, { groups, batch_size, -1 }); // GBOI
        grad_weight = grad_weight.transpose(0, 1); // BGOI
        grad_weight = grad_weight.flatten(1, 2); // B(GO)I
        return std::make_tuple(grad_weight, 0);
      } else {
        // BN(GO), N(GI) -> N(GBO), N(GI) -> (GI)(BO)
        const auto dummy_weight = make_dummy(weight, weight_bdim, 1, batch_size);
        const auto result = at::convolution_backward_symint(
            grad_output_, input, dummy_weight, std::nullopt, stride, padding,
            dilation, transposed, output_padding, groups, mask);
        auto grad_weight = std::get<1>(result);
        grad_weight = reshape_dim_outof_symint(1, batch_size, grad_weight);
        return std::make_tuple(grad_weight, 1);
      }
    }
  } else if (!grad_output_bdim && input_bdim) {
    // Only input batched: fold B into the weight's in-channel dim.
    const auto batch_size = input.size(*input_bdim);
    if (groups == 1) {
      // regular: NO, BNI -> NO, N(BI) -> O(BI)
      // transposed: NO, BNI -> NO, N(BI) -> (BI)O
      const auto input_ = reshape_dim_into(*input_bdim, 1, input);
      const auto in_ch_dim = transposed ? 0 : 1;
      const auto dummy_weight = make_dummy(weight, weight_bdim, in_ch_dim, batch_size);
      const auto result = at::convolution_backward_symint(
          grad_output, input_, dummy_weight, std::nullopt, stride, padding,
          dilation, transposed, output_padding, groups, mask);
      auto grad_weight = std::get<1>(result);
      grad_weight = reshape_dim_outof_symint(in_ch_dim, batch_size, grad_weight);
      return std::make_tuple(grad_weight, in_ch_dim);
    } else {
      // groups > 1: reorder input so that B sits inside each group's
      // in-channel slice before convolving.
      auto input_ = moveBatchDimToFront(input, input_bdim); // BN(GI)
      input_ = reshape_dim_outof_symint(2, groups, input_); // BNGI
      input_ = input_.movedim(0, 2); // NGBI
      input_ = input_.flatten(1, 3); // N(GBI)
      if (!transposed) {
        // regular: N(GO), BN(GI) -> N(GO), N(GBI) -> (GO)(BI)
        const auto dummy_weight = make_dummy(weight, weight_bdim, 1, batch_size);
        const auto result = at::convolution_backward_symint(
            grad_output, input_, dummy_weight, std::nullopt, stride, padding,
            dilation, transposed, output_padding, groups, mask);
        auto grad_weight = std::get<1>(result);
        grad_weight = reshape_dim_outof_symint(1, batch_size, grad_weight);
        return std::make_tuple(grad_weight, 1);
      } else {
        // transposed: N(GO), BN(GI) -> N(GO), N(GBI) -> (GBI)O
        const auto dummy_weight = make_dummy(weight, weight_bdim, 0, batch_size);
        const auto result = at::convolution_backward_symint(
            grad_output, input_, dummy_weight, std::nullopt, stride, padding,
            dilation, transposed, output_padding, groups, mask);
        auto grad_weight = std::get<1>(result);
        grad_weight = grad_weight.unflatten_symint(0, { groups, batch_size, -1 }); // GBIO
        grad_weight = grad_weight.transpose(0, 1); // BGIO
        grad_weight = grad_weight.flatten(1, 2); // B(GI)O
        return std::make_tuple(grad_weight, 0);
      }
    }
  } else {
    // Neither grad_output nor input is batched, so (by the assert below)
    // only weight is; a placeholder of its unbatched shape is used, and
    // grad_weight is returned without a batch dim.
    TORCH_INTERNAL_ASSERT(weight_bdim);
    const auto dummy_weight = make_dummy(weight, weight_bdim, 0, 1);
    const auto result = at::convolution_backward_symint(
        grad_output, input, dummy_weight, std::nullopt, stride, padding,
        dilation, transposed, output_padding, groups, mask);
    return std::make_tuple(std::get<1>(result), std::nullopt);
  }
}
// FuncTorchBatched kernel for aten::convolution_backward.
//
// - If none of grad_output_ / input_ / weight_ is batched at the current vmap
//   level, the op is redispatched unchanged with FuncTorchBatched excluded.
// - grad_bias (if requested) is computed by compute_grad_bias on the
//   still-wrapped grad_output_, i.e. with FuncTorchBatched still active.
//   output_mask[2] is then cleared so that none of the convolution_backward
//   calls below compute it again.
// - If all three inputs are batched, one grouped convolution_backward call
//   computes grad_input and grad_weight together.
// - Otherwise grad_input and grad_weight are computed independently by
//   convolution_backward_input_batch_rule and
//   convolution_backward_weight_batch_rule.
// Results are re-wrapped as batched tensors at the current level; outputs not
// requested by output_mask are returned as undefined Tensors.
static std::tuple<Tensor,Tensor,Tensor> convolution_backward_plumbing(
    const Tensor& grad_output_, const Tensor& input_, const Tensor& weight_,
    const c10::OptionalArrayRef<SymInt> bias_sizes_opt,
    c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed,
    c10::SymIntArrayRef output_padding, c10::SymInt groups, std::array<bool, 3> output_mask) {
  const auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "convolution_backward_plumbing");
  int64_t cur_level = maybe_layer->layerId();
  // Nothing batched at this level: skip the batching logic entirely.
  if (!areAnyBatchedAtLevel({grad_output_, input_, weight_}, cur_level)){
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return at::convolution_backward_symint(
        grad_output_, input_, weight_, bias_sizes_opt, stride, padding,
        dilation, transposed, output_padding, groups, output_mask);
  }
  auto [grad_output, grad_output_bdim] = unwrapTensorAtLevel(grad_output_, cur_level);
  auto [input, input_bdim] = unwrapTensorAtLevel(input_, cur_level);
  auto [weight, weight_bdim] = unwrapTensorAtLevel(weight_, cur_level);
  // Computed from the wrapped grad_output_ before any dispatch-key guard is
  // active, so the reduction is itself vmapped.
  const auto grad_bias = compute_grad_bias(grad_output_, output_mask);
  output_mask[2] = false;
  // TODO: A little bird says that unfold + matmul is actually faster than
  // group convolution in many cases. We should benchmark some of
  // the common cases and replace things with unfold + matmul as necessary.
  // Notation:
  // B - a batch dimension
  // G - groups (sometimes omitted because it doesn't matter)
  // NO - grad_output
  // NI - input
  // OI - weight
  // "(BO)I" - we don't actually care about the values of this Tensor,
  //           we just need to create a tensor on the same device with the
  //           correct shape and pray that the implementation is smart enough
  //           to not do anything with it.
  // BNO, BNI, BOI
  // AKA one of the model ensembling case
  if (grad_output_bdim && input_bdim && weight_bdim) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    grad_output = reshape_dim_into(*grad_output_bdim, 1, grad_output);
    // BNO, BNI, BOI -> N(BO), N(BI), (BO)I
    const auto batch_size = weight.size(*weight_bdim);
    input = reshape_dim_into(*input_bdim, 1, input);
    weight = reshape_dim_into(*weight_bdim, 0, weight);
    // Each example becomes its own group(s): groups is scaled by batch_size.
    const auto result = at::convolution_backward_symint(
        grad_output, input, weight, std::nullopt, stride, padding, dilation,
        transposed, output_padding, batch_size * groups, output_mask);
    // N(BI), (BO)I -> NBI, BOI
    const auto grad_input = output_mask[0] ?
      reshape_dim_outof(1, batch_size, std::get<0>(result)) : Tensor();
    const auto grad_weight = output_mask[1] ?
      reshape_dim_outof(0, batch_size, std::get<1>(result)) : Tensor();
    return std::make_tuple(
        output_mask[0] ? makeBatched(grad_input, 1, cur_level) : grad_input,
        output_mask[1] ? makeBatched(grad_weight, 0, cur_level) : grad_weight,
        grad_bias);
  }
  // General case: compute each requested gradient with its own batch rule.
  Tensor grad_input;
  if (output_mask[0]) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    const auto result = convolution_backward_input_batch_rule(
        grad_output, grad_output_bdim,
        input, input_bdim,
        weight, weight_bdim,
        stride, padding, dilation, transposed, output_padding, groups);
    grad_input = makeBatched(std::get<0>(result), std::get<1>(result), cur_level);
  }
  Tensor grad_weight;
  if (output_mask[1]) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    const auto result = convolution_backward_weight_batch_rule(
        grad_output, grad_output_bdim,
        input, input_bdim,
        weight, weight_bdim,
        stride, padding, dilation, transposed, output_padding, groups);
    grad_weight = makeBatched(std::get<0>(result), std::get<1>(result), cur_level);
  }
  return std::make_tuple(grad_input, grad_weight, grad_bias);
  // The commented-out code below is unreachable (it follows the return) and
  // is kept only for reference.
  // Someone's definitely going to find a problem with this batching rule so
  // I'm leaving the following fallback if we need it back.
  // static auto op = c10::Dispatcher::singleton()
  //   .findSchemaOrThrow("aten::convolution_backward", "");
  // auto result = slow_fallback<Tensor,Tensor,Tensor>(op, {
  //   grad_output_, input_, weight_, bias_sizes_opt,
  //   stride, padding, dilation, transposed, output_padding, groups, output_mask
  // });
  // return std::make_tuple(grad_input, std::get<1>(result), grad_bias);
}
// Registers this file's FuncTorchBatched (vmap) kernels for the aten ops:
//   aten::_convolution         -> _convolution_decomp (forwards to convolution)
//   aten::convolution_backward -> convolution_backward_plumbing
//   aten::convolution          -> convolution_batch_rule, wired up by VMAP_SUPPORT
TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
  m.impl("_convolution", _convolution_decomp);
  m.impl("convolution_backward", convolution_backward_plumbing);
  VMAP_SUPPORT(convolution, convolution_batch_rule);
}
} // namespace at::functorch