blob: 683ed76829188106f7ad8d876e8040773d8ad5bd [file] [log] [blame]
// NB: Must be at the top of file to avoid including the deprecated "math.h".
// https://stackoverflow.com/questions/6563810/m-pi-works-with-math-h-but-not-with-cmath-in-visual-studio
#ifdef _MSC_VER
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES
#endif
#include <cmath>
#endif
#include "torch/csrc/autograd/generated/Functions.h"
#include <ATen/Utils.h>
#include <c10/core/TensorOptions.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/ExpandUtils.h>
#include <ATen/core/Reduction.h>
#include <ATen/Dispatch.h>
#include <ATen/ScalarOps.h>
#include <ciso646>
#include <algorithm>
#include <numeric>
#include <functional>
// ${generated_comment}
using at::Tensor;
using at::Scalar;
using at::IntArrayRef;
using at::TensorList;
namespace torch { namespace autograd { namespace generated {
namespace {
// True only when the optional actually holds a tensor AND that tensor is
// defined (an engaged optional may still wrap an undefined Tensor handle).
bool isDefined(const c10::optional<Tensor>& t) {
  if (!t.has_value()) {
    return false;
  }
  return t->defined();
}
// Unwraps an optional tensor; an empty optional becomes an undefined Tensor.
Tensor toLegacyTensor(const c10::optional<Tensor>& t) {
  if (t.has_value()) {
    return *t;
  }
  return Tensor();
}
// Helper functions for autogenerated code
// Hands out consecutive, non-overlapping index ranges for slots that have
// been flattened into a single list. Each range(n) call reserves the next n
// indices and returns them as a {begin, end} pair (end exclusive). size()
// reports how many indices have been handed out so far.
struct IndexRangeGenerator {
  IndexRange range(size_t range_size) {
    const size_t begin = next_;
    next_ = begin + range_size;
    return {begin, next_};
  }
  size_t size() { return next_; }
 private:
  size_t next_ = 0;
};
// Stores a single output tensor into its one-element slot `range` of `out`.
// Asserts that the slot is in bounds and spans exactly one element.
void copy_range(variable_list& out, IndexRange range, const Tensor & t) {
  AT_ASSERT(range.second <= out.size());
  AT_ASSERTM(range.second - range.first == 1, "inconsistent range for Tensor output");
  auto& slot = out[range.first];
  slot = t;
}
// Stores a list of output tensors into the consecutive slots described by
// `range`. Asserts that the range is in bounds and exactly as long as `t`.
void copy_range(variable_list& out, IndexRange range, at::ArrayRef<Tensor> t) {
  AT_ASSERT(range.second <= out.size());
  AT_ASSERTM(range.second - range.first == t.size(), "inconsistent range for TensorList output");
  auto dst = out.begin() + range.first;
  for (const Tensor& tensor : t) {
    *dst = tensor;
    ++dst;
  }
}
// Always throws std::runtime_error naming the op whose derivative is missing.
// The Tensor return type is presumably there so the call can stand in for a
// gradient expression in generated code; it never actually returns.
Tensor not_implemented(const char* name) {
  std::string message("the derivative for '");
  message += name;
  message += "' is not implemented";
  throw std::runtime_error(message);
}
// Returns t * s, except that a multiplier of exactly 1 (floating-point 1.0,
// or an integral/bool 1) returns `t` unchanged without a multiply.
Tensor maybe_multiply(const Tensor & t, const Scalar & s) {
  const bool is_one = s.isFloatingPoint() ? s.toDouble() == 1
                    : s.isIntegral(true)  ? s.toLong() == 1
                                          : false;
  return is_one ? t : t * s;
}
// Number of elements covered by the (possibly negative) dims in `dim`, i.e.
// the product of sizes[d] over them. A 0-dim shape (empty `sizes`) yields 1
// regardless of `dim`.
int64_t _safe_size(IntArrayRef sizes, IntArrayRef dim) {
  if (sizes.empty()) {
    return 1;
  }
  int64_t product = 1;
  for (const int64_t raw_dim : dim) {
    product *= sizes[at::maybe_wrap_dim(raw_dim, sizes.size())];
  }
  return product;
}
// Converts `scalar` to a tensor via scalar_to_tensor and flags its impl as a
// "wrapped number" (ATen's marker for tensors that stand in for Python
// scalars, e.g. during type promotion).
static Tensor wrapped_scalar_tensor(Scalar scalar) {
  Tensor wrapped = scalar_to_tensor(scalar);
  wrapped.unsafeGetTensorImpl()->set_wrapped_number(true);
  return wrapped;
}
std::tuple<Tensor, Tensor> _euclidean_dist_backward(const Tensor & grad, const Tensor & x1, const Tensor & x2, const Tensor & res) {
if (!grad.defined()) {
return std::tuple<Tensor, Tensor>(Tensor(), Tensor());
}
// handle case at 0 where we return a subgradient containing 0
Tensor ratio = grad / res;
ratio.masked_fill_(res == 0, 0);
return std::tuple<Tensor, Tensor>{
x1 * ratio.sum(-1, true) - ratio.matmul(x2),
x2 * ratio.sum(-2, false).unsqueeze(-1) - ratio.transpose(-2, -1).matmul(x1)};
}
// Gradient of the p-norm `norm` of `self` (p defaults to 2).
// NOTE(review): `norm` and `grad` are assumed to already broadcast against
// `self` (the dim/keepdim overload below re-inserts reduced dims first).
// Each branch chooses two factors so that the result is self_scaled * scale_v,
// except p == 0 and p == 1 which return directly.
Tensor norm_backward(const Tensor & grad, const Tensor & self, const optional<Scalar> & p_, const Tensor & norm) {
double p = p_.value_or(2.0).toDouble();
Tensor self_scaled;
Tensor scale_v;
if (p == 0.0) {
// p == 0: the gradient is treated as identically zero.
return at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
} else if (p == 1.0) {
// d|x|/dx = sign(x).
return self.sign() * grad;
} else if (p == 2.0) {
// d||x||_2/dx = x / ||x||_2.
self_scaled = self;
scale_v = grad / norm;
} else if (std::isinf(p)) {
// +/-inf: gradient flows only to entries whose |x| equals the norm.
self_scaled = self.sign() * (self.abs() == norm).type_as(self);
scale_v = grad.clone(at::MemoryFormat::Preserve);
} else if (p < 2.0) {
// General p: d||x||_p/dx = sign(x) * |x|^(p-1) / ||x||_p^(p-1).
self_scaled = self.sign() * self.abs().pow(p - 1);
scale_v = grad / norm.pow(p - 1);
} else {
// p > 2: same formula, with sign(x)|x|^(p-1) written as x * |x|^(p-2).
self_scaled = self * self.abs().pow(p - 2);
scale_v = grad / norm.pow(p - 1);
}
// handle case at 0 where we return a subgradient containing 0
scale_v.masked_fill_(norm == 0, 0);
return self_scaled * scale_v;
}
// Gradient of norm over `dim`. If the reduced dims were squeezed away
// (keepdim == false on a non-scalar input), they are re-inserted as size-1
// axes in both grad and norm so these broadcast against `self`. The work is
// then delegated to the full-reduction overload.
Tensor norm_backward(Tensor grad, const Tensor & self, const optional<Scalar> & p_, Tensor norm, IntArrayRef dim, bool keepdim) {
  if (keepdim || self.dim() == 0) {
    return norm_backward(grad, self, p_, norm);
  }
  const size_t ndim = self.sizes().size();
  if (dim.size() == 1) {
    grad = grad.unsqueeze(dim[0]);
    norm = norm.unsqueeze(dim[0]);
  } else {
    // Walk positions in increasing order so earlier insertions line up.
    const auto reduced = at::dim_list_to_bitset(dim, ndim);
    for (size_t pos = 0; pos < ndim; ++pos) {
      if (!reduced[pos]) {
        continue;
      }
      grad = grad.unsqueeze(pos);
      norm = norm.unsqueeze(pos);
    }
  }
  return norm_backward(grad, self, p_, norm);
}
// Gradient of self^exponent with respect to self for a scalar exponent:
// grad * exponent * self^(exponent - 1). A zero exponent gives a zero
// gradient, since x^0 is constant.
Tensor pow_backward(Tensor grad, const Tensor & self, const Scalar & exponent_) {
  const double exponent = exponent_.toDouble();
  if (exponent != 0.0) {
    return grad * exponent * self.pow(exponent - 1);
  }
  return at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
}
// Gradient of self^exponent with respect to self for a tensor exponent,
// elementwise: 0 where exponent == 0, grad * exponent * self^(exponent - 1)
// elsewhere.
Tensor pow_backward_self(Tensor grad, const Tensor & self, const Tensor & exponent) {
  const Tensor zero = at::zeros({}, grad.options());
  const Tensor general = grad * exponent * self.pow(exponent - 1);
  return at::where(exponent == 0.0, zero, general);
}
// Gradient of a^b with respect to the exponent b (tensor base a):
// d(a^b)/db = a^b * log(a) = result * log(self).
//
// Caveats:
// At a = 0 and b < 0 we define d(a^b)/db = -inf, because d(a^b)/db -> -inf
// for fixed b as a -> +0. (TensorFlow currently defines it as NaN there.)
//
// At a = 0 and b = 0 we define d(a^b)/db = 0 by continuity, since
// d(a^b)/db -> 0 for a > 0 as b -> +0. (TensorFlow currently agrees.)
Tensor pow_backward_exponent(Tensor grad, const Tensor& self, const Tensor& exponent, Tensor result) {
  // Zero base with non-negative exponent: derivative pinned to 0 (see above).
  const Tensor pinned_to_zero = at::logical_and(self == 0, exponent >= 0);
  const Tensor d_exponent = at::where(pinned_to_zero,
                                      at::zeros({}, grad.options()),
                                      result * self.log());
  return grad * d_exponent;
}
// Gradient of base^b with respect to the exponent b for a scalar base:
// d(base^b)/db = result * log(base). For base == 0 the derivative is pinned
// to 0 wherever b >= 0; elsewhere the -inf from log(0) propagates.
Tensor pow_backward_exponent(Tensor grad, const Scalar & base, const Tensor& exponent, Tensor result) {
  const double base_value = base.toDouble();
  const double log_base = std::log(base_value);
  if (base_value != 0) {
    return grad * result * log_base;
  }
  return grad * at::where(exponent >= 0,
                          at::zeros({}, grad.options()),
                          result * log_base);
}
// Gradient of the multivariate log-gamma of order p:
// d/dx log Gamma_p(x) = sum_{i=0}^{p-1} digamma(x - i/2).
// The offsets -i/2 are produced by arange((1 - p)/2, 0.5, 0.5) = {(1-p)/2, ..., 0}
// and broadcast against self along a new trailing axis.
Tensor mvlgamma_backward(Tensor grad, const Tensor & self, int64_t p) {
  Tensor offsets = at::arange(-p / 2. + 0.5, 0.5, 0.5, self.options());
  Tensor shifted = offsets.add(self.unsqueeze(-1));
  return grad * shifted.digamma_().sum(-1);
}
// The gradient of permute(fwd_dims) is grad permuted by the inverse
// permutation: if output dim o came from input dim fwd_dims[o], then input
// dim fwd_dims[o] must come from grad dim o.
Tensor permute_backwards(const Tensor & grad, IntArrayRef fwd_dims) {
  const auto ndims = fwd_dims.size();
  std::vector<int64_t> inverse(ndims);
  for (size_t out_dim = 0; out_dim < ndims; ++out_dim) {
    const auto in_dim = at::maybe_wrap_dim(fwd_dims[out_dim], ndims);
    inverse[in_dim] = out_dim;
  }
  return grad.permute(inverse);
}
// rad2deg(x) = x * 180/pi, so the gradient is grad scaled by 180/pi.
Tensor rad2deg_backward(const Tensor& grad) {
  constexpr double kDegreesPerRadian = 57.295779513082320876798154814105170332405472466564;
  const Tensor scale = wrapped_scalar_tensor(Scalar(kDegreesPerRadian));
  return at::mul(grad, scale);
}
// deg2rad(x) = x * pi/180, so the gradient is grad scaled by pi/180.
Tensor deg2rad_backward(const Tensor& grad) {
  constexpr double kRadiansPerDegree = 0.017453292519943295769236907684886127134428718885417;
  const Tensor scale = wrapped_scalar_tensor(Scalar(kRadiansPerDegree));
  return at::mul(grad, scale);
}
// Inserts a size-1 axis at every position in `dim`, where positions are
// interpreted against an n_dims-dimensional result. Positions are visited in
// increasing order, so each insertion lands where the final shape expects it.
Tensor unsqueeze_multiple(const Tensor & t, IntArrayRef dim, size_t n_dims) {
  const auto wanted = at::dim_list_to_bitset(dim, n_dims);
  Tensor out = t;
  size_t pos = 0;
  while (pos < n_dims) {
    if (wanted[pos]) {
      out = out.unsqueeze(pos);
    }
    ++pos;
  }
  return out;
}
// Gradient of sum over `dims`: broadcast grad back to the input shape
// `sizes`. When keepdim was false on a non-scalar input, the reduced dims are
// first restored as size-1 axes.
Tensor sum_backward(const Tensor & grad, IntArrayRef sizes, IntArrayRef dims, bool keepdim) {
  if (keepdim || sizes.empty()) {
    return grad.expand(sizes);
  }
  const Tensor restored = dims.size() == 1
      ? grad.unsqueeze(dims[0])
      : unsqueeze_multiple(grad, dims, sizes.size());
  return restored.expand(sizes);
}
// Gradient of nansum over `dims`: like sum_backward, but elements of `self`
// that are NaN (and so were skipped by the forward) receive zero gradient.
Tensor nansum_backward(const Tensor & grad, const Tensor & self, IntArrayRef dims, bool keepdim) {
  const auto sizes = self.sizes();
  Tensor expanded;
  if (keepdim || sizes.size() == 0) {
    expanded = grad.expand(sizes);
  } else if (dims.size() == 1) {
    expanded = grad.unsqueeze(dims[0]).expand(sizes);
  } else {
    expanded = unsqueeze_multiple(grad, dims, sizes.size()).expand(sizes);
  }
  return expanded * self.isnan().logical_not();
}
// Returns a copy of `list` with its elements in reverse order.
std::vector<int64_t> reverse_list(const IntArrayRef list) {
  return std::vector<int64_t>(list.rbegin(), list.rend());
}
// Reverses `t` along `dim` by selecting indices size(dim)-1, ..., 0.
Tensor reverse_dim(const Tensor& t, int64_t dim) {
  const int64_t last = t.size(dim) - 1;
  const Tensor reversed_index = at::arange(last, -1, -1, t.options().dtype(at::kLong));
  return t.index_select(dim, reversed_index);
}
// Division-free gradient of prod along `dim`. At position k the gradient is
// grad * (product of every other element), computed as
//   (exclusive cumprod from the front) * (exclusive cumprod from the back),
// so it stays correct when `inp` contains zeros.
Tensor prod_safe_zeros_backward(const Tensor &grad, const Tensor& inp, int64_t dim) {
  const int64_t n = inp.size(dim);
  if (n == 1) {
    return grad;
  }
  auto ones_shape = inp.sizes().vec();
  ones_shape[dim] = 1;
  const Tensor ones = at::ones(ones_shape, grad.options());
  // [1, x0, x0*x1, ..., x0*...*x_{n-2}]
  const Tensor prefix = at::cat({ones, inp.narrow(dim, 0, n - 1)}, dim).cumprod(dim);
  // [x1*...*x_{n-1}, ..., x_{n-1}, 1]
  const Tensor tail_reversed = reverse_dim(inp.narrow(dim, 1, n - 1), dim);
  const Tensor suffix = reverse_dim(at::cat({ones, tail_reversed}, dim).cumprod(dim), dim);
  return grad * (prefix * suffix);
}
// The gradient of prod is equivalent to
//   cumprod(exclusive, forward) * cumprod(exclusive, reverse), e.g.:
//   input:                        [ a,     b,     c    ]
//   cumprod(exclusive, forward):  [ 1,     a,     a * b]
//   cumprod(exclusive, reverse):  [ b * c, c,     1    ]
//   product:                      [ b * c, a * c, a * b]
// and this formulation is safe for inputs containing 0s.
Tensor prod_backward(const Tensor& grad, const Tensor& input, const Tensor& result) {
  if (input.dim() == 0) {
    return grad;
  }
  const Tensor zero_positions = (input == 0).nonzero();
  if (zero_positions.numel() == 0) {
    // No zeros: d(prod)/dx_i = prod / x_i.
    return (grad * result) / input;
  }
  if (zero_positions.size(0) > 1) {
    // Two or more zeros: every "product of the others" still contains a zero.
    return at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  }
  // Exactly one zero: use the division-free form on the flattened input.
  return prod_safe_zeros_backward(grad, input.contiguous().view(-1), 0).view_as(input);
}
// Gradient of prod along a single `dim`. The reduced dim is re-inserted when
// keepdim was false (unless the input was 1-D, whose output is a scalar that
// already broadcasts). The fast division formula is used when the input has
// no zeros; otherwise the division-free variant.
Tensor prod_backward(Tensor grad, const Tensor& input, Tensor result, int64_t dim, bool keepdim) {
  if (input.dim() == 0) {
    return grad;
  }
  dim = at::maybe_wrap_dim(dim, input.sizes().size());
  if (!keepdim && input.dim() != 1) {
    grad = grad.unsqueeze(dim);
    result = result.unsqueeze(dim);
  }
  const Tensor zeros_per_slice = (input == 0).sum(dim, true);
  const int64_t total_zeros = zeros_per_slice.sum().item<int64_t>();
  return total_zeros == 0 ? (grad * result) / input
                          : prod_safe_zeros_backward(grad, input, dim);
}
// Reverse (suffix) cumulative sum along `dim`: out[k] = x[k] + x[k+1] + ... + x[last].
// Despite the name, position k itself is included. It is computed from a
// forward cumsum: total - prefix_sum(k) + x[k].
Tensor sum_scan_exclusive(const Tensor& x, int64_t dim) {
  Tensor suffix = at::cumsum(-x, dim);                 // -(x[0] + ... + x[k])
  const int64_t last = suffix.size(dim) - 1;
  const Tensor neg_total = suffix.narrow(dim, last, 1).clone(at::MemoryFormat::Preserve);
  suffix -= neg_total.expand_as(suffix);               // x[k+1] + ... + x[last]
  suffix += x;                                         // x[k] + ... + x[last]
  return suffix;
}
Tensor cumprod_backward(const Tensor &grad, const Tensor &input, int64_t dim) {
/*
Gradient of y = cumprod(input) along `dim`.

There are two algorithms to do this. The first one is very efficient, but
works only when there are no zero elements in the input.
The second one is much more complex, but it doesn't assume anything about
the input. The main downside is that it takes time O(n^2), where
n = input.size(dim) (i.e. the length of the cumulative product). This is in
contrast to the forward pass and the efficient algorithm, which are both O(n).

The second algorithm is a simple application of the chain rule. If x is an
n-dimensional vector, y = cumprod(x), and F is the final cost, then

  dF / dx_k = sum_j (dF / dy_j) * (dy_j / dx_k)                         (1)

The term dF / dy_j is just grad_output[j] (assuming again everything is
one-dimensional). The term (dy_j / dx_k) is easily seen to be

  if j >= k:  dy_j / dx_k = prod_{1 <= i <= j, i != k} x_i
  else:       dy_j / dx_k = 0

The indicator (j >= k) can be taken out by replacing the sum in (1) with a
sum from j = k to n. Thus,

  dF / dx_k = sum_{k <= j <= n} grad_output[j] * (dy_j / dx_k)
with
  dy_j / dx_k = prod_{1 <= i <= j, i != k} x_i                          (2)

This last term is just the cumulative product with x_k omitted. Thus, if
x_k (the input) is nonzero, we can express it as

  dy_j / dx_k = (prod_{1 <= i <= j} x_i) / x_k = y_j / x_k

so

  dF / dx_k = sum_{k <= j <= n} grad_output[j] * y_j / x_k

i.e. grad_input = sum_scan_exclusive(grad_output * output) / input, where
sum_scan_exclusive returns the suffix sums over j >= k (despite its name,
the term j = k is included).

If the input contains zeros, we need to calculate dy_j / dx_k using
formula (2), called omitted_products in the code. The code computes it by
noting that

  prod_{1 <= i <= j, i != k} x_i
    = (prod_{1 <= i <= k - 1} x_i) * (prod_{k + 1 <= i <= j} x_i)

The first factor is prods_until_k; since it doesn't depend on j it is easy
to vectorize. The second factor (indexed by j) is the cumulative product of
x_{k+1}, x_{k+2}, ..., x_n; it is named prods_from_k_plus_1 in the code and
computed with cumprod.
omitted_products only holds the entries with j >= k. The terms with j < k
(where dy_j / dx_k = 0) are dropped by slicing grad from k onward, which is
done right after the assert.
*/
if (input.dim() == 0 || input.numel() == 0) {
return grad;
}
dim = at::maybe_wrap_dim(dim, input.sizes().size());
int64_t dim_size = input.size(dim);
if (dim_size == 1) {
return grad;
}
// Fast path: the input has no zero elements, so dy_j / dx_k = y_j / x_k.
if ((input != 0).all().item<uint8_t>()) {
Tensor result = at::cumprod(input, dim);
return sum_scan_exclusive(result * grad, dim) / input;
}
// Slow O(n^2) path for inputs containing zeros: evaluate formula (2) for
// each k in turn and write the reduced result into slice k of grad_input.
auto ones_size = input.sizes().vec();
ones_size[dim] = 1;
Tensor ones = at::ones({1}, grad.options()).expand(ones_size);
Tensor grad_input = at::zeros(input.sizes(), grad.options());
Tensor prods_from_k_plus_1;
Tensor omitted_products;
for (int k = 0; k < dim_size; ++k) {
if (k == 0) {
// Nothing precedes x_0: [1, x_1, x_1*x_2, ...].
prods_from_k_plus_1 = at::cumprod(input.slice(dim, k + 1), dim);
omitted_products = at::cat({ones, prods_from_k_plus_1}, dim);
} else if (k == dim_size - 1) {
// Nothing follows the last element: only the prefix product remains.
Tensor prods_until_k = at::prod(input.slice(dim, 0, k), dim, true);
omitted_products = prods_until_k;
} else {
Tensor prods_until_k = at::prod(input.slice(dim, 0, k), dim, true);
prods_from_k_plus_1 = at::cumprod(input.slice(dim, k+1), dim);
omitted_products = prods_until_k.expand_as(prods_from_k_plus_1) * prods_from_k_plus_1;
// Entry j == k is the prefix product alone.
omitted_products = at::cat({prods_until_k, omitted_products}, dim);
}
// At this point omitted_products is the same size
// as input, except on the dimension dim where it's
// dim_size - k
AT_ASSERT(omitted_products.size(dim) == dim_size - k);
grad_input.select(dim, k).copy_(
at::sum(grad.slice(dim, k) * omitted_products,dim));
}
return grad_input;
}
// Overload for the cumprod variant that takes an output dtype. The incoming
// grad is cast to the input's dtype and the regular backward is reused;
// `dtype` itself is not consulted.
Tensor cumprod_backward(const Tensor &grad, const Tensor &input, int64_t dim, optional<ScalarType> dtype) {
  const Tensor grad_as_input_dtype = grad.to(input.scalar_type());
  return cumprod_backward(grad_as_input_dtype, input, dim);
}
// Gradient of solve(self, A) with respect to the right-hand side `self`:
// the solution of A^T X = grad (assuming at::solve(B, M) solves M X = B and
// returns the solution as the first tuple element).
Tensor solve_backward_self(const Tensor & grad, const Tensor & self, const Tensor & A) {
  const Tensor A_transposed = A.transpose(-2, -1);
  auto solved = at::solve(grad, A_transposed);
  return std::get<0>(solved);
}
// Gradient of solve(self, A) with respect to A: -(A^{-T} grad) @ solution^T.
// Plain 2-D operands use mm; anything batched goes through matmul.
Tensor solve_backward_A(const Tensor & grad, const Tensor & self, const Tensor & A, const Tensor & solution) {
  const Tensor grad_self = solve_backward_self(grad, self, A);
  const Tensor solution_t = solution.transpose(-2, -1);
  const bool plain_matrices = self.ndimension() == 2 && A.ndimension() == 2;
  return plain_matrices ? -at::mm(grad_self, solution_t)
                        : -at::matmul(grad_self, solution_t);
}
// Gradient of cumsum along `dim`, applied to the incoming gradient `x`:
// a reversed cumulative sum, out[k] = x[k] + x[k+1] + ... + x[last].
// Scalars and empty tensors (no values to scan) are returned unchanged.
Tensor cumsum_backward(const Tensor & x, int64_t dim) {
  if (x.dim() == 0 || x.numel() == 0) {
    return x;
  }
  auto suffix = at::cumsum(-x, dim);
  const auto neg_total = suffix.narrow(dim, suffix.size(dim) - 1, 1).clone(at::MemoryFormat::Preserve);
  suffix -= neg_total.expand(suffix.sizes());
  suffix += x;
  return suffix;
}
// Gradient of cummax: each output position's gradient is scatter-added onto
// the input element recorded in `indices` (the running-max position), so an
// element that stays the max for several positions accumulates their grads.
Tensor cummax_backward(const Tensor &indices, const Tensor &grad, const Tensor &input, int64_t dim) {
  if (input.numel() == 0) {
    return input;
  }
  Tensor grad_input = at::zeros(input.sizes(), input.options());
  grad_input.scatter_add_(dim, indices, grad);
  return grad_input;
}
// Gradient of cummin: each output position's gradient is scatter-added onto
// the input element recorded in `indices` (the running-min position), so an
// element that stays the min for several positions accumulates their grads.
Tensor cummin_backward(const Tensor &indices, const Tensor &grad, const Tensor &input, int64_t dim) {
  if (input.numel() == 0) {
    return input;
  }
  Tensor grad_input = at::zeros(input.sizes(), input.options());
  grad_input.scatter_add_(dim, indices, grad);
  return grad_input;
}
// Gradient of logsumexp over `dim`: grad * exp(self - logsumexp(self)),
// i.e. grad times the softmax over the reduced dims. Reduced dims are
// re-inserted into grad and result when keepdim was false.
Tensor logsumexp_backward(Tensor grad, const Tensor & self, Tensor result, IntArrayRef dim, bool keepdim) {
  const bool dims_were_dropped = !keepdim && self.dim() != 0;
  if (dims_were_dropped) {
    const size_t ndim = self.sizes().size();
    grad = unsqueeze_multiple(grad, dim, ndim);
    result = unsqueeze_multiple(result, dim, ndim);
  }
  return grad * (self - result).exp();
}
// Gradient of result = logcumsumexp(self) along `dim`.
// Since d result[j] / d self[k] = exp(self[k] - result[j]) for k <= j,
//   grad_self[k] = sum_{j >= k} grad[j] * exp(self[k] - result[j])
//                = exp(self[k] + reverse_logcumsumexp(log(grad) - result)[k]).
// log(grad) only exists for positive grad, so the positive and negative
// parts of grad are handled separately and subtracted at the end.
Tensor logcumsumexp_backward(Tensor grad, const Tensor & self, Tensor result, int64_t dim) {
if (grad.dim() == 0 || grad.numel() == 0) {
return grad;
}
// Reference: https://github.com/tensorflow/tensorflow/blob/
// 2a5910906a0e0f3dbc186ff9db6386d81a63448c/tensorflow/python/ops/math_grad.py#L1832-L1863
return AT_DISPATCH_FLOATING_TYPES(
at::typeMetaToScalarType(grad.dtype()),
"logcumsumexp_backward",
[grad, self, result, dim]() {
// Most negative finite value of the dtype; used in place of log(0) for
// entries whose grad does not have the sign handled by that branch
// (presumably to avoid -inf arithmetic).
auto grad_min = at::empty_like(grad);
grad_min.fill_(std::numeric_limits<scalar_t>::lowest());
auto log_grad_positive = at::where(grad > 0, grad.log(), grad_min);
auto log_grad_negative = at::where(grad < 0, (-grad).log(), grad_min);
// Suffix logcumsumexp along dim: flip, scan forward, flip back.
auto reverse_logcumsumexp = [dim](auto x) {
return at::flip(at::logcumsumexp(at::flip(x, {dim}), dim), {dim});
};
// Contribution of the positive part of grad.
auto output_pos =
(reverse_logcumsumexp(log_grad_positive - result) + self).exp();
// Contribution of the (magnitude of the) negative part of grad.
auto output_neg =
(reverse_logcumsumexp(log_grad_negative - result) + self).exp();
return output_pos - output_neg;
});
}
// Gradient of unbind(dim): stack the per-slice gradients back along `dim`.
// Slices that received no gradient (undefined tensors) are filled with zeros
// shaped and typed like the first defined gradient.
Tensor unbind_backward(const variable_list& grads, int64_t dim) {
  IntArrayRef sizes;
  at::TensorOptions o;
  // Iterate by const reference. The previous by-value loop copied every
  // Variable (a refcount bump each) and took `sizes` from a loop-local copy.
  // Here `sizes` views metadata of a tensor owned by `grads`, which outlives
  // this function.
  for (const auto& v : grads) {
    if (v.defined()) {
      sizes = v.sizes();
      o = static_cast<Tensor>(v).options();
      break;
    }
  }
  auto grads_tensors = fmap(grads, [&](const Variable& v) {
    return (
        v.defined() ? static_cast<Tensor>(v) : at::zeros({}, o).expand(sizes));
  });
  return at::stack(grads_tensors, dim);
}
// Inserts a size-1 axis into `self` at every position d where sizes[d] == 1,
// visiting positions left to right.
Tensor unsqueeze_to(const Tensor & self, IntArrayRef sizes) {
  Tensor out = self;
  for (size_t d = 0; d < sizes.size(); ++d) {
    if (sizes[d] == 1) {
      out = out.unsqueeze(d);
    }
  }
  return out;
}
// Re-inserts axis `dim` into `self` when the original shape `sizes` had a
// size-1 axis there. A 0-dim original shape (empty `sizes`) is never
// unsqueezed: NumPy allows unsqueezing a scalar in the forward, but the
// backward must not add that axis back.
Tensor unsqueeze_to(const Tensor & self, int64_t dim, IntArrayRef sizes) {
  dim = at::maybe_wrap_dim(dim, sizes.size());
  const bool restore_axis = !sizes.empty() && sizes[dim] == 1;
  return restore_axis ? self.unsqueeze(dim) : self;
}
// Gradient of cat: slice `grad` back into one piece per input along `dim`.
// `sizes` holds each input's shape. Legacy empty inputs of shape [0] get an
// empty gradient and do not advance the running offset. Returns all-undefined
// gradients when `grad` is undefined.
std::vector<Tensor> cat_tensors_backward(const Tensor & grad, const std::vector<std::vector<int64_t>> &sizes, int64_t dim) {
  std::vector<Tensor> grad_inputs(sizes.size());
  if (!grad.defined()) {
    return grad_inputs;
  }
  dim = at::legacy_cat_wrap_dim(dim, sizes);
  const std::vector<int64_t> legacy_empty_shape{0};
  int64_t offset = 0;
  for (size_t i = 0; i < sizes.size(); ++i) {
    const auto& shape = sizes[i];
    if (shape == legacy_empty_shape) {
      grad_inputs[i] = at::zeros({0}, grad.options());
      continue;
    }
    const int64_t length = shape[dim];
    grad_inputs[i] = grad.narrow(dim, offset, length);
    offset += length;
  }
  return grad_inputs;
}
// Gradient of clamp: grad passes through where self lies within the given
// bounds (inclusive, i.e. subgradient 1 exactly at min/max) and is zero
// elsewhere. A missing bound imposes no constraint.
Tensor clamp_backward(const Tensor & grad, const Tensor &self, const optional<Scalar> & min, const optional<Scalar> & max) {
  if (!min && !max) {
    return grad;
  }
  if (!max) {
    return grad * (self >= *min).type_as(grad);
  }
  if (!min) {
    return grad * (self <= *max).type_as(grad);
  }
  return grad * ((self >= *min) * (self <= *max)).type_as(grad);
}
// Gradient of mm/addmm with respect to mat1: alpha * grad @ mat2^T. If mat1
// was column-major, the equivalent (mat2 @ grad^T)^T is computed so that the
// gradient comes out column-major as well. Sparse mat1 is rejected.
Tensor mm_mat1_backward(const Tensor & grad, const Tensor & mat2, const Tensor & mat1, const Scalar & alpha) {
  if (mat1.is_sparse()) {
    throw std::runtime_error("calculating the gradient of a sparse Tensor argument to mm is not supported.");
  }
  const at::IntArrayRef sizes = mat1.sizes();
  const at::IntArrayRef strides = mat1.strides();
  const bool column_major = strides[0] == 1 && strides[1] == sizes[0];
  const Tensor product = column_major ? mat2.mm(grad.t()).t() : grad.mm(mat2.t());
  return maybe_multiply(product, alpha);
}
// Gradient of mm/addmm with respect to mat2: alpha * mat1^T @ grad.
// `sizes`/`strides` describe the original mat2. If it was column-major, the
// gradient is produced column-major too, for efficiency.
Tensor mm_mat2_backward(const Tensor & grad, const Tensor & mat1, IntArrayRef sizes, IntArrayRef strides, const Scalar & alpha) {
  // if input was column-major, return grad as column-order for efficiency
  if (strides[0] == 1 && strides[1] == sizes[0]) {
    if (mat1.is_sparse()) {
      // Since mm(dense, sparse) doesn't exist,
      // pass a transposed output matrix to the underlying "addmm"
      // function directly.
      int64_t out_rows = mat1.size(1);
      int64_t out_cols = grad.size(1);
      Tensor t = at::zeros({}, grad.options()).expand({out_rows, out_cols}, true);
      Tensor r = at::empty({out_cols, out_rows}, grad.options()).t();
      // addmm_out(out, self, mat1, mat2, beta, alpha) computes
      // out = beta * self + alpha * (mat1 @ mat2). `t` is all zeros, so beta
      // has no effect; the scaling must go in the alpha slot. Passing alpha as
      // beta (as before) silently dropped it on this path.
      at::addmm_out(r, t, mat1.t(), grad, /*beta=*/1, /*alpha=*/alpha);
      return r;
    }
    return maybe_multiply(grad.t().mm(mat1).t(), alpha);
  } else {
    return maybe_multiply(mat1.t().mm(grad), alpha);
  }
}
// Gradient of the sparse operand of addmm: alpha * grad @ dense^T, masked to
// the sparsity pattern of the coalesced sparse input.
Tensor _sparse_addmm_sparse_backward(const Tensor& grad, const Tensor& sparse_, const Tensor& dense, const Scalar& alpha) {
  AT_ASSERT(sparse_.is_sparse());
  const Tensor pattern = sparse_.coalesce();
  const Tensor dense_grad = maybe_multiply(grad.mm(dense.t()), alpha);
  return dense_grad.sparse_mask(pattern);
}
// Gradient of renorm(p, dim, maxnorm). Slices along `dim` whose p-norm is
// below maxnorm receive `grad` unchanged (the final where()). The remaining
// slices get the gradient of maxnorm * x / ||x||_p, presumably because the
// forward rescaled them to norm maxnorm.
Tensor renorm_backward(const Tensor & grad, const Tensor & self, Scalar p, int64_t dim, Scalar maxnorm) {
// Shape of self with dims 0 and `dim` swapped (target shape for unflatten).
auto transposed_sizes = self.transpose(dim, 0).sizes().vec();
// Move `dim` to the front and collapse the rest: shape (size(dim), -1).
auto flatten = [&](const Tensor & t) {
return t.transpose(dim, 0).contiguous().view({t.size(dim), -1});
};
// Inverse of flatten: restore the original layout.
auto unflatten = [&](const Tensor & t) {
return t.contiguous().view(transposed_sizes).transpose(dim, 0);
};
// renorm computes the norm over all dimensions except `dim`, which is why
// we need the flatten and unflatten business. TODO: simplify this when we
// add support for norm over multiple dimensions.
auto self_flat = flatten(self);
auto grad_flat = flatten(grad);
// Per-slice p-norm, shape (size(dim), 1).
auto norm_flat = self_flat.norm(p, 1, true);
// Per-slice <x, grad>, used as the incoming gradient of the norm below.
auto grad_output = (self_flat * grad_flat).sum(1, true);
// <x, grad> * d||x||_p / dx
auto nb = norm_backward(grad_output, self_flat, p, norm_flat, 1, true);
// 1e-7 guards against dividing by a zero norm.
auto invnorm = (norm_flat + 1e-7).reciprocal();
// d(maxnorm * x / n)^T grad = maxnorm / n * (grad - <x, grad> * dn/dx / n)
auto grad_norm = unflatten(maxnorm * invnorm * (grad_flat - invnorm * nb));
auto norm = unflatten(norm_flat.expand_as(self_flat));
// TODO: remove the detach once comparison ops no longer require grad
auto mask = Variable(norm < maxnorm).detach();
return at::where(mask, grad, grad_norm);
}
// Elementwise sum of a non-empty list of tensors, accumulated left to right.
// Throws std::runtime_error for an empty list.
Tensor sum_tensorlist(TensorList tl) {
  if (tl.empty()) {
    throw std::runtime_error("Can't sum tensorlist of size 0");
  }
  Tensor total = tl.front();
  for (const Tensor& t : tl.slice(1)) {
    total = total + t;
  }
  return total;
}
// Gradient of repeat(repeats) for an input with `input_dims` dims. First, the
// leading dims that repeat prepended are summed away. Then, for each dim
// repeated r > 1 times, the r tiled copies are folded back by summing its
// r chunks.
Tensor repeat_backward(Tensor grad, int64_t input_dims, IntArrayRef repeats) {
  const int64_t num_unsqueezed = grad.dim() - input_dims;
  for (int64_t remaining = num_unsqueezed; remaining > 0; --remaining) {
    grad = grad.sum(0, false);
  }
  for (size_t j = num_unsqueezed; j < repeats.size(); ++j) {
    const int64_t repeat = repeats[j];
    if (repeat != 1) {
      const int64_t dim = j - num_unsqueezed;
      grad = sum_tensorlist(grad.chunk(repeat, dim));
    }
  }
  return grad;
}
// Gradient of fused dropout: grad * mask / p1m, where p1m == 1 - p is the
// keep probability.
Tensor _fused_dropout_backward(Tensor grad, Tensor mask, double p1m) {
  const double scale = 1. / p1m;
  if (!grad.requires_grad()) {
    return at::_masked_scale(grad, mask, scale);
  }
  // Double backward is required: use autograd-friendly ops instead.
  return grad * (mask.type_as(grad) * scale);
}
// Routes `grad` to the first element of `input` that equals `value`. All
// other elements get zero gradient.
Tensor select_first_equal_backward(Tensor grad, const Tensor & input, const Tensor & value) {
  auto grad_input = at::zeros_like(input);
  // Coordinates (one per dim) of the first element with input[idx] == value.
  auto first_value_idx = (input == value).nonzero().select(0, 0);
  if (grad_input.dim() != 0) {
    grad_input.index_put_(at::chunk(first_value_idx, grad_input.dim()), grad);
  } else {
    grad_input.copy_(grad);
  }
  return grad_input;
}
// Scatters `grad` into a zero tensor of shape `sizes` at `indices` along
// `dim`. When the forward dropped `dim` (keepdim == false on a non-scalar
// input), it is first restored on both grad and indices.
Tensor index_select_backward(Tensor grad, int64_t dim, Tensor indices, IntArrayRef sizes, bool keepdim) {
  if (!(keepdim || sizes.empty())) {
    grad = grad.unsqueeze(dim);
    indices = indices.unsqueeze(dim);
  }
  Tensor grad_input = at::zeros(sizes, grad.options());
  return grad_input.scatter_(dim, indices, grad);
}
// Gradient of slice: a zero tensor of the input's shape with `grad` written
// into the sliced window.
Tensor slice_backward(Tensor grad, IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) {
  Tensor grad_input = at::zeros(input_sizes, grad.options());
  Tensor window = grad_input.slice(dim, start, end, step);
  window.copy_(grad);  // writes through the view into grad_input
  return grad_input;
}
// Gradient of select: a zero tensor of the input's shape with `grad` written
// into the selected slice.
Tensor select_backward(Tensor grad, IntArrayRef input_sizes, int64_t dim, int64_t index) {
  Tensor grad_input = at::zeros(input_sizes, grad.options());
  Tensor picked = grad_input.select(dim, index);
  picked.copy_(grad);  // writes through the view into grad_input
  return grad_input;
}
// Gradient of trace for a 2-D input of shape `sizes`: `grad` on the main
// diagonal, zero elsewhere. Throws std::runtime_error for non-2-D shapes.
//
// In the row-major flattening, diagonal entry k sits at k * (cols + 1), and
// there are min(rows, cols) such entries. Bounding the index range by numel()
// alone also picked off-diagonal entries of tall matrices with
// rows > cols + 1 (e.g. flat index 2 = element (2, 0) of a 3x1 matrix).
Tensor trace_backward(const Tensor & grad, IntArrayRef sizes) {
  if (sizes.size() != 2) {
    throw std::runtime_error("expected matrix input");
  }
  const int64_t rows = sizes[0];
  const int64_t cols = sizes[1];
  const int64_t diag_len = std::min(rows, cols);
  const int64_t step = cols + 1;
  auto grad_input = at::zeros(rows * cols, grad.options());
  // Exactly diag_len indices: 0, step, ..., (diag_len - 1) * step.
  auto indices = at::arange(0, diag_len * step, step, grad.options().dtype(at::kLong));
  grad_input.index_fill_(0, indices, grad);
  return grad_input.view(sizes);
}
// Gradient of var over all elements: 2 * grad * (self - mean) / (N - c),
// where c = 1 for the unbiased estimator and 0 otherwise.
Tensor var_backward(const Tensor & grad, const Tensor & self, bool unbiased) {
  const auto denominator = self.numel() - unbiased;
  return (2.0 / denominator) * grad * (self - self.mean());
}
// Gradient of var over `dim`: 2 * grad * (self - mean(dim)) / (N - c), with
// N the number of reduced elements and c = 1 when unbiased. Scalars defer to
// the full-reduction overload. Reduced dims are re-inserted into grad when
// keepdim was false on an input with more than one dim.
Tensor var_backward(Tensor grad, const Tensor & self, IntArrayRef dim, bool unbiased, bool keepdim) {
  if (self.dim() == 0) {
    return var_backward(grad, self, unbiased);
  }
  if (!keepdim && self.dim() > 1) {
    grad = unsqueeze_multiple(grad, dim, self.sizes().size());
  }
  const auto reduced_count = _safe_size(self.sizes(), dim);
  const double scale = 2.0 / (reduced_count - unbiased);
  return scale * grad * (self - self.mean(dim, true));
}
// std = sqrt(var), so d std = d var / (2 * std). Rescale grad accordingly
// and reuse var_backward.
Tensor std_backward(const Tensor & result, const Tensor & grad, const Tensor & self, bool unbiased) {
  const Tensor var_grad = grad / (result * 2);
  return var_backward(var_grad, self, unbiased);
}
// Reduction-over-`dim` variant: rescale grad by 1 / (2 * std) and reuse the
// matching var_backward overload.
Tensor std_backward(const Tensor & result, Tensor grad, const Tensor & self, IntArrayRef dim, bool unbiased, bool keepdim) {
  const Tensor var_grad = grad / (result * 2);
  return var_backward(var_grad, self, dim, unbiased, keepdim);
}
// Gradient of mean over `dim`: the sum gradient divided by the number of
// reduced elements.
Tensor mean_backward(Tensor grad, const IntArrayRef sizes, IntArrayRef dim, bool keepdim) {
  const int64_t reduced_count = _safe_size(sizes, dim);
  return sum_backward(grad, sizes, dim, keepdim) / reduced_count;
}
// Gradient of a full mean over `numel` elements: broadcast grad to the input
// shape `sizes` and divide by the element count.
// `numel` is int64_t (previously int). Callers pass Tensor::numel(), which is
// int64_t, and narrowing it to int corrupted the divisor for tensors with
// more than INT_MAX elements.
Tensor mean_backward(Tensor grad, const IntArrayRef sizes, int64_t numel) {
  return grad.expand(sizes) / numel;
}
// Backward of the fused (var|std, mean) reduction over `dim`.
// grads[0] is the gradient of the var/std output (r1 is that output), and
// grads[1] is the gradient of the mean output. Their contributions are
// summed; the result stays undefined if neither gradient is defined.
Tensor var_std_mean_backward(const variable_list& grads, const Tensor & self, const Tensor & r1, const Tensor & r2, IntArrayRef dim, bool unbiased, bool keepdim, bool is_std) {
  const Tensor& grad_stat = grads[0];
  const Tensor& grad_mean = grads[1];
  Tensor grad;
  if (grad_stat.defined()) {
    if (is_std) {
      grad = std_backward(r1, grad_stat, self, dim, unbiased, keepdim);
    } else {
      grad = var_backward(grad_stat, self, dim, unbiased, keepdim);
    }
  }
  if (grad_mean.defined()) {
    Tensor mean_grad = mean_backward(grad_mean, self.sizes(), dim, keepdim);
    grad = grad_stat.defined() ? grad + mean_grad : mean_grad;
  }
  return grad;
}
// Backward of the fused (var|std, mean) full reduction.
// grads[0] is the gradient of the var/std output (r1 is that output), and
// grads[1] is the gradient of the mean output. Their contributions are
// summed; the result stays undefined if neither gradient is defined.
Tensor var_std_mean_backward(const variable_list& grads, const Tensor & self, const Tensor & r1, const Tensor & r2, bool unbiased, bool is_std) {
  const Tensor& grad_stat = grads[0];
  const Tensor& grad_mean = grads[1];
  Tensor grad;
  if (grad_stat.defined()) {
    if (is_std) {
      grad = std_backward(r1, grad_stat, self, unbiased);
    } else {
      grad = var_backward(grad_stat, self, unbiased);
    }
  }
  if (grad_mean.defined()) {
    Tensor mean_grad = mean_backward(grad_mean, self.sizes(), self.numel());
    grad = grad_stat.defined() ? grad + mean_grad : mean_grad;
  }
  return grad;
}
// Gradient of masked_scatter with respect to the source tensor of shape
// `sizes`. Returns the elements of grad selected by `mask` (in order),
// zero-padded up to the source's element count and reshaped to `sizes`.
Tensor masked_scatter_backward(const Tensor & grad, const Tensor & mask, IntArrayRef sizes) {
  const int64_t numel = std::accumulate(sizes.begin(), sizes.end(), int64_t{1},
                                        std::multiplies<int64_t>());
  auto mask_selected = grad.masked_select(mask);
  const auto diff_nelem = numel - mask_selected.numel();
  if (diff_nelem > 0) {
    // masked_select yields a 1-d tensor with one entry per set mask element;
    // fill out the remainder with zeros so it can be reshaped to `sizes`.
    auto zeros_fillin = at::zeros({diff_nelem}, grad.options());
    mask_selected = at::cat({mask_selected, zeros_fillin}, 0);
  }
  return mask_selected.view(sizes);
}
// Gradient of the Cholesky factorization. `L` is the factor returned by the
// forward and `grad` the gradient with respect to it. With `upper`, both are
// transposed first so the math below works on the lower-triangular factor.
// Returns grad_input = sym(L^{-T} Phi(L^T grad) L^{-1}), where Phi keeps the
// lower triangle and halves the diagonal, and sym(X) = (X + X^T) / 2.
Tensor cholesky_backward(Tensor grad, bool upper, Tensor L) {
// cf. Iain Murray (2016); arXiv 1602.07527
// This gradient is symmetric, and not triangular.
// Cholesky additionally assumes that the input is symmetric, which is a subspace of
// R^{n x n}, and hence the derivative is not well-defined for off-diagonal
// elements. We resolve this by taking the gradient of the functionally independent
// elements of the matrix (i.e., the lower triangular portion of the input) and then
// reflect it on the upper triangular portion, thereby symmetrizing the gradient of
// the cholesky operation. The motivation behind this choice is that symmetric gradient
// leads to stable gradient updates, and retains symmetry of the updated matrix if it
// were updated by a gradient based algorithm.
if (upper) {
L = L.transpose(-1, -2);
grad = grad.transpose(-1, -2);
}
// L^{-1}, obtained by solving L X = I with L lower-triangular
// (assuming at::triangular_solve(b, A) solves A X = b).
auto L_inverse = std::get<0>(at::triangular_solve(at::eye(L.size(-1), L.options()), L, /*upper=*/false));
// Phi(L^T grad): keep the lower triangle and halve its diagonal in place.
auto phi = at::matmul(L.transpose(-1, -2), grad);
phi.tril_().diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).mul_(0.5);
// L^{-T} Phi L^{-1}
auto grad_input = at::matmul(at::matmul(L_inverse.transpose(-1, -2), phi), L_inverse);
return grad_input.add(grad_input.transpose(-1, -2)).mul_(0.5); // Symmetrizing the gradient
}
// Gradient of cholesky_inverse with respect to the factor L, given the
// forward result `inverse`. It is -(inv (G + G^T) inv) multiplied by L (on
// the left when `upper`, on the right otherwise). An undefined grad yields
// an all-zero gradient shaped like L.
Tensor cholesky_inverse_backward(Tensor grad, Tensor L, bool upper, Tensor inverse) {
  if (!grad.defined()) {
    return at::zeros({1}, L.options()).expand_as(L);
  }
  const Tensor symmetric_grad = grad + grad.transpose(-2, -1);
  const Tensor common_term = at::matmul(inverse, at::matmul(symmetric_grad, inverse));
  return upper ? -at::matmul(L, common_term) : -at::matmul(common_term, L);
}
// Gradient of split_with_sizes: concatenate the per-chunk gradients back
// along `dim`. Undefined chunk gradients stand for all-zero tensors.
// at::cat cannot take undefined tensors, so zeros of the chunk's shape
// (`sizes` with split_sizes[j] along `dim`) are materialized for them.
Tensor split_with_sizes_backward(const std::vector<torch::autograd::Variable> &grads,
    IntArrayRef split_sizes, int64_t dim, IntArrayRef sizes, const at::TensorOptions &options) {
  dim = at::maybe_wrap_dim(dim, sizes.size());
  std::vector<Tensor> grads_all_defined;
  grads_all_defined.reserve(grads.size());
  for (size_t j = 0; j < grads.size(); ++j) {
    if (grads[j].defined()) {
      grads_all_defined.push_back(grads[j]);
      continue;
    }
    auto chunk_shape = sizes.vec();
    chunk_shape[dim] = split_sizes[j];
    grads_all_defined.push_back(at::zeros(chunk_shape, options));
  }
  return at::cat(grads_all_defined, dim);
}
// Gradient of split(split_size): every chunk has split_size elements along
// `dim` except the last, which holds the remainder. These sizes are passed
// on to split_with_sizes_backward.
Tensor split_backward(const std::vector<torch::autograd::Variable> &grads,
    int64_t split_size, int64_t dim, IntArrayRef sizes, const at::TensorOptions &options) {
  dim = at::maybe_wrap_dim(dim, sizes.size());
  const int64_t dim_size = sizes[dim];
  const int64_t num_splits = grads.size();
  std::vector<int64_t> split_sizes(num_splits, split_size);
  split_sizes.back() = dim_size - split_size * (num_splits - 1);
  return split_with_sizes_backward(grads, split_sizes, dim, sizes, options);
}
// Double backward of max pooling over the trailing `dim` dims. The trailing
// dims of both grad and indices are flattened into one axis, grad is gathered
// at the recorded `indices`, and the result is reshaped like `indices`.
Tensor max_pool_double_backward(const Tensor & grad, const Tensor & indices, int dim) {
  AT_ASSERT(indices.dim() >= dim);
  auto flat_shape = indices.sizes().slice(0, indices.dim() - dim).vec();
  flat_shape.push_back(-1);
  const auto memory_format = indices.suggest_memory_format();
  const Tensor flat_indices = indices.view(flat_shape);
  const Tensor flat_grad = grad.contiguous(memory_format).view(flat_shape);
  return flat_grad.gather(-1, flat_indices).view(indices.sizes());
}
// Double backward of glu with respect to `input`.
// With input = [a, b] split in halves along `dim` and s = sigmoid(b), the
// first backward is gI_a = gO * s and gI_b = gO * a * s(1-s). `grad` is the
// incoming gradient [ggI_a, ggI_b] for that first backward's output. This
// returns d(<grad, gI>)/d input, concatenated back along `dim`.
Tensor glu_double_backward(const Tensor & grad, const Tensor & grad_output, const Tensor & input, int64_t dim) {
auto& gO = grad_output;
auto input_size = input.size(dim) / 2;
// a and b
auto first_half = input.narrow(dim, 0, input_size);
auto second_half = input.narrow(dim, input_size, input_size);
// s, 1 - s, and s' = s(1 - s)
auto sig_second_half = second_half.sigmoid();
auto one_sub_sig_second_half = 1 - sig_second_half;
auto sig_one_sub_sig = sig_second_half * one_sub_sig_second_half;
auto ggI_first_half = grad.narrow(dim, 0, input_size);
auto ggI_second_half = grad.narrow(dim, input_size, input_size);
auto ggI_second_half_times_first_half = ggI_second_half * first_half;
// d/da: only gI_b depends on a.
auto gI_first_half = ggI_second_half * gO * sig_one_sub_sig;
// s'' = s'(1 - s) - s * s'
auto second_order_sh = sig_one_sub_sig * one_sub_sig_second_half - sig_second_half * sig_one_sub_sig;
// d/db: gI_b contributes through s'', gI_a through s'.
auto gI_second_half = ggI_second_half_times_first_half * gO * second_order_sh + ggI_first_half * gO * sig_one_sub_sig;
return at::cat({gI_first_half, gI_second_half}, dim);
}
// Double backward of glu with respect to grad_output. Run glu_backward with
// an all-ones grad_output (shape of the glu output: input halved along `dim`),
// weight it by `grad`, and fold the two halves back together.
Tensor glu_double_backward_grad_output(const Tensor & grad, const Tensor & input, int64_t dim) {
  if (dim < 0) {
    dim += input.dim();
  }
  auto half_sizes = input.sizes().vec();
  half_sizes[dim] /= 2;
  const int64_t half = half_sizes[dim];
  const Tensor tmp = grad * glu_backward(at::ones(half_sizes, input.options()), input, dim);
  return tmp.narrow(dim, 0, half) + tmp.narrow(dim, half, half);
}
// Gradient of exact gelu(x) = x * Phi(x), built from differentiable ops:
// gelu'(x) = Phi(x) + x * phi(x), where Phi is the standard normal CDF and
// phi(x) = exp(-x^2 / 2) / sqrt(2 * pi).
Tensor infinitely_differentiable_gelu_backward(
    const Tensor& grad,
    const Tensor& self) {
  // 2/sqrt(pi) * 1/sqrt(2) * 0.5 == 1/sqrt(2*pi), the normal pdf constant.
  constexpr double kAlpha = M_2_SQRTPI * M_SQRT1_2 * 0.5;
  Tensor pdf = (-0.5 * self * self).exp_();                  // unnormalized phi
  Tensor cdf = (1.0 + (self * M_SQRT1_2).erf_()).mul_(0.5);  // Phi
  return cdf.addcmul_(self, pdf, kAlpha).mul_(grad);
}
// Gradient of silu(x) = x * s(x) with s = sigmoid, built from differentiable
// ops: silu'(x) = s(x) * (1 + x * (1 - s(x))).
Tensor infinitely_differentiable_silu_backward(
    const Tensor& grad_output,
    const Tensor& input) {
  const Tensor s = input.sigmoid();
  const Tensor one_minus_s = 1.0 - s;
  return grad_output * s * (1.0 + input * one_minus_s);
}
Tensor infinitely_differentiable_logit_backward(
    const Tensor& grad,
    const Tensor& self,
    c10::optional<double> eps) {
  // grad / (x * (1 - x)) wherever x lies in [lo, hi]. With `eps` the interval is
  // [eps, 1 - eps] and out-of-range entries get 0; without it the interval is
  // [0, 1] and out-of-range entries get NaN.
  const double lo = eps ? *eps : 0.0;
  const double hi = eps ? 1.0 - *eps : 1.0;
  const Tensor in_range = at::logical_and(self >= lo, self <= hi);
  const Tensor fallback = eps
      ? at::zeros({}, self.options())
      : at::empty({}, self.options())
            .fill_(std::numeric_limits<double>::quiet_NaN());
  return at::where(in_range, grad / (self * (1.0 - self)), fallback);
}
Tensor kl_div_double_backward_grad_output(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction, bool log_target) {
  // Gradient w.r.t. grad_output: run kl_div's unreduced backward with `grad`,
  // then reduce it with `reduction` (mean / sum / none).
  const Tensor unreduced = kl_div_backward(grad, input, target, at::Reduction::None, log_target);
  if (reduction == at::Reduction::Sum) {
    return unreduced.sum();
  }
  if (reduction == at::Reduction::Mean) {
    return unreduced.mean();
  }
  return unreduced;
}
// Compute derivatives for targets.
// kl_div gradient w.r.t. `target`:
//   log_target == false: grad_output * (log(target) + 1 - self), forced to 0
//                        where target == 0;
//   log_target == true:  grad_output * (target + 1 - self) * exp(target).
// Under Mean reduction the result is divided by target.numel().
Tensor kl_div_target_backward(Tensor grad_output, Tensor self, Tensor target, int64_t reduction, bool log_target) {
  Tensor grad_target = log_target
      ? grad_output.mul(target.add(1).sub_(self).mul_(target.exp()))
      : grad_output.mul(target.log().add_(1).sub_(self)).masked_fill_(target == 0, 0.);
  if (reduction == at::Reduction::Mean) {
    grad_target.div_(target.numel());
  }
  return grad_target;
}
Tensor binary_cross_entropy_with_logits_target_backward(const Tensor& grad_output, const Tensor& self, const Tensor& target, const c10::optional<Tensor>& weight, const c10::optional<Tensor>& pos_weight, int64_t reduction) {
  // Gradient of binary_cross_entropy_with_logits w.r.t. `target`.
  // With pos_weight p:     grad_output * [log(1 - sigmoid(x)) - p * log(sigmoid(x))]
  // Without pos_weight it simplifies to grad_output * (-x).
  // The result is then scaled by `weight` (if given) and by 1/numel for Mean.
  Tensor grad_target;
  if (isDefined(pos_weight)) {
    // log(1 - sigmoid(x)) == log_sigmoid(-x). Computing it via log_sigmoid
    // avoids log(0) = -inf once sigmoid(x) rounds to exactly 1 (or 0) for
    // large |x|, which the naive `(1 - sigmoid(x)).log()` form produced.
    grad_target = at::log_sigmoid(-self)
                      .sub_(pos_weight->mul(at::log_sigmoid(self)))
                      .mul_(grad_output);
  } else {
    grad_target = self.mul(-grad_output);
  }
  if (isDefined(weight)) {
    grad_target.mul_(*weight);
  }
  if (reduction == at::Reduction::Mean) {
    grad_target.div_(target.numel());
  }
  return grad_target;
}
Tensor log_sigmoid_double_backward(const Tensor & grad, const Tensor & input) {
  // Second derivative of log(sigmoid(x)): (sigmoid(x) - 1) * sigmoid(x).
  const auto sig = input.sigmoid();
  const auto sig_minus_one = sig - 1;
  return grad * sig_minus_one * sig;
}
Tensor softmax_double_backward(const Tensor & grad, const Tensor & grad_output, int dim, const Tensor & output) {
  // Second-order softmax term expressed through the softmax `output`, the
  // first-order `grad_output` and the incoming `grad`. All reductions are
  // along `dim` with keepdim so they broadcast back against `output`.
  const auto& go = grad_output;
  const auto& gg = grad;

  const auto gg_y = gg * output;
  const auto sum_gg_y = gg_y.sum(dim, true);
  const auto sum_gg_y_times_y = sum_gg_y * output;
  const auto sum_go_y = (go * output).sum(dim, true);

  const auto term0 = gg_y * (go - sum_go_y);
  const auto term1 = output * ((gg_y * go).sum(dim, true).sub_(sum_go_y * sum_gg_y));
  const auto term2 = sum_gg_y_times_y * go;
  const auto term3 = sum_gg_y_times_y * sum_go_y;
  return term0 - term1 - term2 + term3;
}
Tensor log_softmax_double_backward(const Tensor & grad, const Tensor & grad_output, int dim, const Tensor & output) {
  // `output` is log-softmax, so exp(output) recovers the probabilities.
  const auto probs = output.exp();
  const auto go_sum = grad_output.sum(dim, true);
  const auto centered = (grad * probs).sum(dim, true) - grad;
  return probs * go_sum * centered;
}
Tensor binary_cross_entropy_double_backward(const Tensor & grad_output, const Tensor & grad, const Tensor & input, const Tensor & target, const c10::optional<Tensor>& weight, int64_t reduction) {
  // Gradient w.r.t. `input` of binary_cross_entropy's backward.
  // The second derivative of -[y*log(x) + (1-y)*log(1-x)] w.r.t. x is
  //   (x^2 - 2*x*y + y) / (x^2 * (1-x)^2);
  // `eps` keeps the denominator away from zero at x == 0 and x == 1.
  auto eps = 1e-12;
  auto inp_pl_eps = input + eps;
  auto one_m_inp_pl_eps = 1 - input + eps;
  auto gI = (input * input - 2 * input * target + target) / (inp_pl_eps.pow(2) * one_m_inp_pl_eps.pow(2));
  gI *= (grad * grad_output);
  if (isDefined(weight)) {
    gI *= *weight;
  }
  // gI is a per-element gradient w.r.t. `input` and must keep input's shape
  // for every reduction. Mean only adds the forward's 1/N factor; Sum adds a
  // factor of 1. Summing here would collapse the gradient to a scalar.
  if (reduction == at::Reduction::Mean) {
    return gI / input.numel();
  }
  return gI;
}
Tensor binary_cross_entropy_double_backward_grad_output(const Tensor & grad, const Tensor & input, const Tensor & target, const c10::optional<Tensor>& weight, int64_t reduction) {
  // Gradient w.r.t. `grad_output` of binary_cross_entropy's backward.
  // Per element this is grad * weight * (x - y) / (x * (1 - x)), with `eps`
  // keeping the denominator away from zero.
  auto eps = 1e-12;
  auto ggO = (input - target) / ((input + eps) * (1 - input + eps));
  ggO *= grad;
  if (isDefined(weight)) {
    ggO *= *weight;
  }
  // For Mean and Sum the loss, and hence grad_output, is a scalar, so its
  // gradient is the reduction of the per-element terms. Mean also carries
  // the forward's 1/N factor. This matches the Sum branch and the other
  // *_double_backward_grad_output helpers, instead of returning an
  // input-shaped tensor for a scalar grad_output.
  if (reduction == at::Reduction::Mean) {
    return ggO.sum() / input.numel();
  } else if (reduction == at::Reduction::Sum) {
    return ggO.sum();
  }
  return ggO;
}
Tensor l1_loss_double_backward_grad_output(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction) {
  // Gradient w.r.t. grad_output: evaluate l1_loss's unreduced backward at
  // `grad`, then reduce it as requested by `reduction`.
  Tensor per_element = l1_loss_backward(grad, input, target, at::Reduction::None);
  switch (reduction) {
    case at::Reduction::Mean:
      return per_element.mean();
    case at::Reduction::Sum:
      return per_element.sum();
    default:
      return per_element;
  }
}
Tensor smooth_l1_loss_double_backward(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction) {
  // The second derivative is 1 where |input - target| < 1 and 0 elsewhere.
  // Under Mean reduction the result is divided by input.numel().
  const auto quadratic_region = ((input - target).abs() < 1).type_as(grad);
  auto result = grad * quadratic_region;
  if (reduction == at::Reduction::Mean) {
    result /= input.numel();
  }
  return result;
}
Tensor smooth_l1_loss_double_backward_grad_output(const Tensor & grad, const Tensor & grad_output, const Tensor & input, const Tensor & target, int64_t reduction) {
  // Gradient w.r.t. grad_output of smooth_l1_loss's backward.
  if (reduction != at::Reduction::None) {
    // Reduced loss: grad_output is treated as a scalar multiplier. Evaluate the
    // backward with ones, weight it by `grad`, and sum.
    const auto unit_backward = smooth_l1_loss_backward(ones_like(grad_output), input, target, reduction);
    return (unit_backward * grad).sum();
  }
  return smooth_l1_loss_backward(grad, input, target, reduction);
}
Tensor diag_backward(const Tensor & grad, IntArrayRef input_sizes, int64_t diagonal) {
  // Backward of diag for a 1-D or 2-D input.
  const auto ndim = input_sizes.size();
  AT_ASSERT(ndim == 1 || ndim == 2);
  const bool vector_or_square = ndim == 1 || input_sizes[0] == input_sizes[1];
  if (!vector_or_square) {
    // Non-square matrix input: grad.diag() would build a square matrix, so
    // write `grad` onto the selected diagonal of a zero tensor instead.
    auto grad_input = at::zeros(input_sizes, grad.options());
    grad_input.diagonal(diagonal).copy_(grad);
    return grad_input;
  }
  return grad.diag(diagonal);
}
Tensor diagonal_backward(const Tensor & grad, IntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) {
  // Place `grad` on the (offset, dim1, dim2) diagonal of a zero tensor with the
  // input's sizes; every off-diagonal element receives zero gradient.
  Tensor result = at::zeros(input_sizes, grad.options());
  result.diagonal(offset, dim1, dim2).copy_(grad);
  return result;
}
Tensor mse_loss_double_backward(const Tensor & grad, const Tensor & input, int64_t reduction) {
  // Second derivative of (input - target)^2 w.r.t. input is the constant 2.
  // Under Mean reduction it is further divided by input.numel().
  const bool averaged = reduction == at::Reduction::Mean;
  Tensor result = 2 * grad;
  return averaged ? result.div_(input.numel()) : result;
}
Tensor mse_loss_double_backward_grad_output(const Tensor & grad, const Tensor & grad_output, const Tensor & input, const Tensor & target, int64_t reduction) {
  // Gradient w.r.t. grad_output of mse_loss's backward.
  const bool unreduced = reduction == at::Reduction::None;
  if (unreduced) {
    return mse_loss_backward(grad, input, target, reduction);
  }
  // Reduced loss: evaluate the backward with ones, weight by `grad`, and sum.
  const Tensor unit_backward = mse_loss_backward(ones_like(grad_output), input, target, reduction);
  return (unit_backward * grad).sum();
}
Tensor soft_margin_loss_double_backward(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction) {
  // Second derivative of log(1 + exp(-t * x)) w.r.t. x:
  //   t^2 * e / (1 + e)^2 with e = exp(-t * x).
  const auto e = (input * -target).exp();
  const auto one_plus_e = e + 1;
  auto result = grad * (target * target) * e / (one_plus_e * one_plus_e);
  if (reduction == at::Reduction::Mean) {
    result /= input.numel();
  }
  return result;
}
Tensor soft_margin_loss_double_backward_grad_output(const Tensor & grad, const Tensor & grad_output, const Tensor & input, const Tensor & target, int64_t reduction) {
  // Gradient w.r.t. grad_output of soft_margin_loss's backward.
  if (reduction != at::Reduction::None) {
    // Reduced loss: evaluate the backward with ones, weight by `grad`, and sum.
    const auto unit_backward = soft_margin_loss_backward(ones_like(grad_output), input, target, reduction);
    return (unit_backward * grad).sum();
  }
  return soft_margin_loss_backward(grad, input, target, reduction);
}
Tensor softplus_double_backward(const Tensor & grad, const Tensor & input, Scalar beta, Scalar threshold) {
  // Second derivative of softplus: beta * sigmoid'(beta * x) where
  // beta * x < threshold, and 0 in the linear region.
  const auto scaled = input * beta;
  const auto nonlinear_region = (scaled < threshold).type_as(grad);
  return sigmoid_backward(grad, scaled.sigmoid()) * nonlinear_region * beta;
}
// NOTE [ as_strided Backward and layout-aware/agnostic autograd ]
//
// `storage_offset` is ignored for simplicity in this note. If you just want the
// full algorithm without explanation, scroll down to bottom of this note.
//
// Implementing the backward of as_strided is tricky because you have to deal
// with mappings that map one memory location to multiple indices, i.e., the
// output tensor has multiple indices pointing to **overlapping** memory
// addresses. This can happen in all sorts of weird cases. For example,
//
// x = torch.randn(15)
// x.as_strided([3, 3], [1, 0]) # "expand" case
// x.as_strided([3, 3], [2, 1]) # "size too large" case
// x.as_strided([3, 2], [3, 6]) # res[2, 0] points to 2*3 + 0*6 = 6
// # res[0, 1] points to 0*3 + 1*6 = 6
//
// Here is the general strategy we apply in implementing as_strided backward:
// 0. ??? (optimization step. we will talk about this later)
// 1. Create some underlying flattened tensor as if it is the base tensor
// representing the contiguous memory storage for both input and output.
// 2. Use the output geometry to scatter (or index_add) the gradients into
// this storage tensor.
// 3. ??? (fix for input tensor with overlapping memory. we will talk about
// this later)
// 4. Return the as_strided view of the storage tensor using input geometry.
//
// In step (2), if the output tensor doesn't have overlapping memory, we can
// safely scatter (`storage.as_strided(output_geometry).copy_(grad)`);
// otherwise, we must use `index_add` as gradients at different indices may need
// to be summed to a single location.
//
// For example, in this case:
//
// x = torch.randn(3)
// y = x.as_strided([3, 3], [1, 0]) # "expand" case
// # size [ 3, 3]
// # stride [ 1, 0]
// y.backward() # step (1): contiguous storage tensor `s` of size 3, which
// is large enough to be used as underlying storage
// for `x` and `y`.
// s = [ 0, 0, 0]
// # step (2): since `y` has overlapping memory, index_add grad
// into `s` basing on `y`'s geometry, i.e.,
// s[i * y.stride(0) + j * y.stride(1)] += gy[i, j].
// s = [ 3, 3, 3]
// # step (4): as_strided view `s` using `x`'s geometry
// s = [ 3, 3, 3]
// grad_input = s.as_strided(x.size(), x.stride())
// = s.as_strided([3], [1])
// = [ 3, 3, 3]
//
// This is exactly what we would get if using `expand`. However, here the input
// tensor doesn't have overlapping memory. If it does, we must add an extra step
// before (4). Consider this case:
//
// t = torch.randn(3)
// x = t.expand(3, 3) # input with overlapping memory
// # size [3, 3]
// # stride [0, 1]
// y = x.as_strided([1], [1]) # contiguous output
// # size [1]
// # stride [1]
// y.backward() # step (1): contiguous storage tensor `s` of size 3, which
// is large enough to be used as underlying storage
// for `x` and `y`.
// s = [ 0, 0, 0]
// # step (2): scatter grad into `s` basing on `y`'s geometry
// s = [ 1, 0, 0]
// # step (4): as_strided view `s` using `x`'s geometry
// s = [ 1, 0, 0]
// grad_input = s.as_strided([3, 3], [0, 1])
// = s.as_strided([3, 3], [0, 1])
// = [[ 1, 0, 0],
// [ 1, 0, 0],
// [ 1, 0, 0]]
// Is this result correct?
//
// `x.as_strided([1], [1])` call is obviously equivalent with
// `x[(0,) * x.dim()].view(1)` for any `x`. But autograd through the second
// gives gradient `[ [ 1, 0, 0], [ 0, 0, 0], [ 0, 0, 0]]`. For this specific
// case, indexing `x` at any index in first column is also equivalent, and
// yields a gradient of shape `[3 x 3]` containing eight 0's and one 1. There is
// an `x.size(1)`-times difference between these gradients computed from other
// PyTorch ops and the gradient we got from as_strided.
//
// You might conclude that the gradients from as_strided are wrong. However,
// let's first see why they are actually reasonable. Consider the pointwise
// perturbations by `delta` anywhere in the first column of `x`. It will lead to
// a `delta` change in the same memory location, and then `y` will change by
// `delta`. So one can say the gradient should be exactly 1 at the first column,
// as given by our above procedure.
//
// In the above computation of numerical gradients, they only match the
// analytical results because strides and memory locations are considered in the
// forward pass, i.e., this op (including both forward and backward) is
// layout-aware.
//
// However, in PyTorch, most (probably all) other ops (forward and backward) are
// layout-agnostic. E.g.,
//
// t = torch.randn(1)
// x = t.expand(2)
// y = x.sum()
// y.backward()
//
// Layout-agnostic autograd (as it is currently in PyTorch) will give you
//
// gy = 1
// gx = [ 1, 1] # SumBackward: torch.ones_like(x)
// gt = [ 2] # ExpandBackward: gx.sum()
//
// Note that `gx = [ 1, 1]`. However, if you perturb any value in `x` by `delta`
// (the other will also change by `delta`), `y` will change by `2 * delta`. So
// the gradients, if strides are taken into consideration, should be 2.
//
// Layout-aware autograd should give you
//
// gy = 1
// gx = [ 2, 2] # Because the backward considers the fact that the input `x`
// # is already expanded.
// gt = [ 2] # Layout-aware backward of expand is just a slicing because
// # the previous backward should have already taken care of
// # strides and made sure that gradients are the same along the
// # expanded dimension.
//
// As shown above, these two types are not compatible. Therefore, we must either
// make as_strided layout-agnostic, or make all other ops layout-aware.
//
// It is difficult to support layout-aware autograd (at least in the current
// codebase structure), because it would mean
// 1. storing tensor geometries of every input tensor for backward
// 2. depending on input geometry, the gradient computed from backward changes
// 3. ideally enforcing gradient of T to always have same strides as T
// (although these two methods only differ when it comes to overlapping memory)
//
// Therefore, we must formulate `as_strided` in a layout-agnostic way, i.e.,
// giving the same output regardless of the input layout. We consider
// `input.stride()` as a separate independent fixed argument `input_stride`.
// Then, `as_strided(input, size, stride)` can be thought of as:
// 1. "Scatter" each value of `input` into a "storage" using storage location
// computed from the value's index in `input`, `input.size()` and
// `input_stride`, but if N values end up in the same location, the value
// is average of those N values (they will be the same value anyways).
//
// Formal description:
// Denote the set of all input indices that point to the same storage
// location `storage[n]` as `S(n)`, i.e.,
//
// S(n) = { index : <index, input_stride> == n, index is valid given input.size() },
//
// where `<x, y>` is the dot product between `x` and `y`.
//
// Then, the process is:
//
// storage[n] = Avg { S(n) }
//
// Note that all values in `S(n)` are the same (they point to the same
// memory location anyways), so this step doesn't change anything, but
// effectively avoids having the dependency on the layout of `input`.
// I.e., the result holds fixed regardless of the layout of `input`, as
// long as `input_stride` is fixed.
//
// NOTE: for forward pass, we can equivalently simply select any one of
// `S(n)` as `storage[n]`. However, considering this as an average
// operation makes backward easier (so all values in set
// `{ grad_input[i] : i in S(n) }` are the same, and it can use the
// same geometry as input).
// 2. As usual, return the as_strided view of `storage` using required output
// `size` and `stride`.
//
// To backward through this layout-agnostic version, we simply add the following
// step:
// .... (scatter gradients into the storage tensor using output geometry)
// 3. For all storage location n, `storage[n] /= |S(n)|`.
// .... (return as_strided view of the storage tensor using input geometry)
//
// Finally, we note that these general operations are expensive, so we apply the
// following optimizations:
// Add step (0): For all output dimension `d` with output stride 0, sum the
// gradients along dimension `d` (don't keepdim), and remove
// dimension `d` from output size and stride.
// (An optimization for "expand" cases so we may avoid step (3))
// Only apply step (3) when input tensor has overlapping memory.
//
// FULL ALGORITHM:
// 0. For all output dimension `d` with output stride 0, sum the gradients
// along dimension `d` (don't keepdim), and remove dimension `d` from
// output size and stride.
// 1. Create some underlying flattened tensor as if it is the base tensor
// representing the contiguous memory storage for both input and output.
// 2. Use the output geometry to scatter (or index_add) the gradients into
// this storage tensor `storage`.
// 3. If input tensor has overlapping memory,
// For all storage location `i`, `storage[i] /= N(i)`, where `N(i)` is the
// number of indices in input geometry pointing to the same storage
// location `i` (i.e., `|S(i)|` in equations above).
// 4. Return the as_strided view of the storage tensor using input geometry.
//
// See NOTE [ Detecting Memory Overlap Within A Strided Tensor ] on how to
// roughly detect overlapping memory.
// NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
//
// Checking memory overlap within a strided tensor is the special case of
// detecting memory overlap of two strided tensors, where the two tensors start
// at the same memory address. The latter is HARD (see #8212).
//
// But even this special case isn't simple. This note describes a check for an
// even more constrained simple case where we can be certain that there is no
// overlap.
//
// The checking algorithm can be described as:
// 0. Return [ pass check ] if any dimension has size 0
// 1. Ignore all dimensions that have size 1
// 2. If no remaining dimensions, return [ pass check ]
// 3. Sort the remaining dimensions according to the strides decreasingly
// 4. Check that for each dimension k,
//
// stride[k] > \sum_{ i > k } (size[i] - 1) * stride[i]
//
// That is equivalent to, after reordering the dimensions so strides are
// in decreasing order, checking that stride of each dimension is larger
// than the maximum memory offset in a slice at that dimension.
//
// Obviously this check passes for contiguous tensors ( the dimensions will be
// already sorted with LHS = stride[0] = \prod size[i] being exactly 1 larger
// than RHS ). Similarly, the check passes for tensors contiguous in all but
// the last dimension, and LHS = stride[0] = stride[-1] * \prod size[i] being
// exactly stride[-1] larger than RHS. (*)
//
// We will show that these view operations, including all our view operations
// *except for* general as_strided and unfold, also preserve this invariant:
//
// alias: Obviously preserves
//
// expand: All changed dimensions are removed in step (1)
//
// view: Consider the input dimensions as grouped into consecutive
// dimension "blocks", where dimensions are contiguous in each
// one. view only works when the output dimensions can also be
// grouped into the same consecutive blocks of same ordering.
//
// NB: this means that the number of elements and stride of the
// last dimension in each block is the same in input and
// output. (**)
//
// Notation:
// Consider a single such block B,
// ... B_prev[-1]], [ B[0], ..., B[i], ..., B[k] = B[-1] ], [ B_next[0], ...
// start--^^^^ ^^^^^^^^^^^^--end
// Each B[i] denotes a dimension index such that B[i] = B[0] + i.
//
// We first show that if a tensor (i.e., input) satisfies the
// invariant, after sorting, the dimensions within each block
// still remain consecutive. (***)
//
// After removing dimensions of size 1, the dimensions within a
// block is already sorted by strides in descending order. So
// sorting all dimensions will not change the relative ordering
// among them.
//
// Assume that some block B is not consecutive after sorting,
// i.e., there exists a dimension d between B[0] and B[-1] in
// sorted order.
//
// By (*), we know that
// stride[B[0]]
// = \sum_{i > 0} (size[B[i]] - 1) * stride[B[i]] + stride[B[-1]]
// < \sum_{i > 0} (size[B[i]] - 1) * stride[B[i]] + stride[d]
// <= \sum_{i > 0} (size[B[i]] - 1) * stride[B[i]] + (size[d] - 1) * stride[d]
// <= \sum_{j > B[0]} (size[j] - 1) * stride[j],
//
// where the first < comes from sorting and
// the second <= comes from the fact that dimension d
// exists after step (1) and
// thus must have size greater
// than 1
// the third <= comes from the fact that each term in
// the sum is non-negative
//
// Then we have a contradiction as the invariant must not be
// satisfied at B[0]. So the original proposition is true.
//
// Now that we established the above claim (***), we consider the
// view operation as first sorting the dimensions (i.e., blocks),
// apply the original view (since it only cares dimensions being
// consecutive and contiguous within each block), and then undo
// the sort.
//
// Consider a single block B in the output,
// ... ], [ B[0], ..., B[i], ..., B[k] = B[-1] ], [ ...
// start--^^^^ ^^^^^^^^^^^^--end
//
// By (*), we know that for all i
// stride[i] = stride[B[-1]] +
// \sum_{j=i+1}^{k} (size[B[j]] - 1) * stride[B[j]]
//
// Then the invariant is obviously satisfied at every dimension
// in this block if it is satisfied at dimension B[-1]. It only
// remains to show that it is satisfied at the last dimension in
// each block.
//
// Since the same blocks are present in both input and output
// with the same ordering, we will abuse the notation in the
// following statements.
//
// By (*), we know that the following holds for both input and
// output, for any block B:
// \sum_{i > B[-1]} (size[i] - 1) * stride[i]
// = \sum_{block B' after B} \prod_{j in B'} size[B[j]] * stride[B'[-1]]
// = \sum_{block B' after B} numel(B') * stride[B'[-1]].
// ^^^^^^^^^^^^^^^^^^^^^^^|^^^^^^^^^^^^^^^^^^^^^^^^^^
// By (**), we know that, this quantity in the above equation
// remains the same in input and output. So both
// \sum_{i > B[-1]} (size[i] - 1) * stride[i]
// and
// stride[B[-1]]
// are the same in input and output.
//
// These two quantities are exactly the LHS and RHS of the
// invariant inequality. Since by assumption the invariant is
// satisfied in input at B[-1], it is also satisfied in output at
// B[-1]. This concludes the proof.
//
// squeeze: Special case of view
//
// unsqueeze: Special case of view
//
// slice: Consider slicing dimension i with step = k >= 1.
//
// Let stride' and size' be the output strides and sizes. We have
//
// stride'[i] = k * stride[i]
// size'[i] <= floor(size[i] / k)
//
// If size'[i] = 1, invariant is obviously satisfied as we are
// just removing a dimension (after step (1)).
//
// Assume size'[i] > 1.
//
// By assumption, the invariant is satisfied at every dimension
// in input.
//
// For any dimension j, if stride[j] > stride[i], we have
// stride'[j] = stride[j]
// > (size[i] - 1) * stride[i]
// = (size[i] / k * k - 1) * k * stride[i] / k
// = (size[i] / k - 1 / k) * stride'[i]
// >= (size'[i] - 1 / k) * stride'[i]
// >= stride'[i].
//
// If stride[j] < stride[i], we have
// stride'[j] = stride[j] < stride[i] <= stride'[i].
//
// So the sorting order remains unchanged after slice.
//
// Since
// (size'[i] - 1) * stride'[i]
// = (floor(size[i] / k) - 1) * k * stride[i]
// <= (size[i] / k - 1) * k * stride[i]
// = (size[i] - k) * stride[i]
// <= (size[i] - 1) * stride[i],
// the term from this dimension i in the invariant inequality at
// other dimensions can only decrease after slice. So the
// invariant is preserved.
//
// narrow: Special case of slice
//
// select: narrow + squeeze
//
// permute: Sorting makes permutation of dimensions irrelevant
//
// transpose: Sorting makes swapping dimensions irrelevant
//
// diagonal: Effectively merging two dimensions i and j into a new
// dimension k s.t.
// stride'[k] = stride[i] + stride[j]
// size'[k] <= min(size[i], size[j]),
// where stride and size are on the input, and stride' and size'
// are on the output.
//
// Assuming that size[i] > 1 and size[j] > 1. If any has size 1,
// then this is unsqueeze on that dimension.
//
// WLOG, say stride[i] >= stride[j].
//
// Each dimension d in input with stride[d] > stride[j] has
// stride'[d] = stride[d]
// > (size[i] - 1) * stride[i] + (size[j] - 1) * stride[j]
// >= stride[i] + stride[j]
// = stride[k].
// So, considering the sorted dimensions, this is effectively
// removing i, and replacing j with k.
//
// For dimensions d with stride[i] < stride[d] < stride[j], the
// term from dimension i is removed in the invariant inequality.
// For dimensions d with stride[d] > stride[j], we have
// (size'[k] - 1) * stride'[k]
// <= (min(size[i], size[j]) - 1) * (stride[i] + stride[j])
// <= (size[i] - 1) * stride[i] + (size[j] - 1) * stride[j],
// so the term from i and j in the invariant can only decrease.
//
// So this is generally relaxing the constraint, and thus it
// preserves it.
// Helper for as_strided_backward.
// Implements steps (2)~(4) of the check in
// NOTE [ Detecting Memory Overlap Within A Strided Tensor ]; as_strided_backward
// performs steps (0)~(1) (dropping size-0 / size-1 dimensions) before calling.
// Returns false only when the geometry is certainly overlap-free; true means
// two indices *may* map to the same memory offset.
static inline bool _maybe_overlapping_memory(IntArrayRef sizes, IntArrayRef strides) {
  if (sizes.empty()) {
    return false;
  }
  // Visit dimensions from the smallest stride to the largest.
  std::vector<std::size_t> order(sizes.size());
  std::iota(order.begin(), order.end(), 0);
  std::sort(order.begin(), order.end(),
            [&](std::size_t lhs, std::size_t rhs) { return strides[lhs] < strides[rhs]; });
  // Largest offset reachable using only the dimensions visited so far.
  int64_t reach = 0;
  for (const std::size_t d : order) {
    if (strides[d] <= reach) {
      return true;
    }
    reach += strides[d] * (sizes[d] - 1);
  }
  return false;
}
// Helper for as_strided_backward.
// Returns the minimum number of storage elements needed to hold a tensor with
// the given sizes, strides and storage_offset: one past the largest offset it
// addresses. A tensor with any size-0 dimension addresses nothing, so only
// `storage_offset` elements are required in that case.
static inline int64_t _min_storage_size(IntArrayRef sizes, IntArrayRef strides, int64_t storage_offset) {
  const bool has_empty_dim =
      std::any_of(sizes.begin(), sizes.end(), [](int64_t s) { return s == 0; });
  if (has_empty_dim) {
    return storage_offset;
  }
  int64_t last_offset = storage_offset;
  for (std::size_t d = 0; d < sizes.size(); ++d) {
    last_offset += (sizes[d] - 1) * strides[d];
  }
  return last_offset + 1;
}
// See NOTE [ as_strided Backward and layout-aware/agnostic autograd ] for explanation
// Backward of as_strided: maps `grad`, laid out with the output geometry
// (`sizes`, `strides`, and `storage_offset_`, which defaults to the input's
// storage offset), back onto a tensor with the geometry of `input_geometry`.
// Returns a zero tensor with the input's sizes if either geometry has a size-0
// dimension.
// NOTE(review): indexes `sizes`/`strides` by `grad` dims, so it assumes
// sizes.size() == strides.size() == grad.dim() -- confirm at call sites.
Tensor as_strided_backward(Tensor grad, TensorGeometry input_geometry, IntArrayRef sizes, IntArrayRef strides, optional<int64_t> storage_offset_) {
// For output geometry,
// check for size 0 dimensions,
// skip size 1 dimensions,
// reduce grad on expanded dims (stride=0, size>1)
// Step (0) for the algorithm in NOTE [ as_strided Backward and layout-aware/agnostic autograd ]
// Step (0)~(1) for the algorithm in NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
// on output geometry
auto storage_offset = storage_offset_.value_or(input_geometry.storage_offset());
auto odim = grad.dim();
std::vector<int64_t> out_sizes_, out_strides_;
out_sizes_.reserve(odim);
out_strides_.reserve(odim);
// Walk dims from last to first so that squeezing/summing dim i never shifts
// the positions of the dims still to be visited (all of which are < i).
for (int64_t i = odim - 1; i >= 0; i--) {
auto size_i = sizes[i];
auto stride_i = strides[i];
if (size_i == 0) {
return at::zeros(input_geometry.sizes(), grad.options());
} else if (size_i == 1) {
grad = grad.squeeze(i);
} else if (stride_i == 0) {
grad = grad.sum(i, false);
} else {
// Prepend so the kept dims end up in their original first-to-last order.
out_sizes_.insert(out_sizes_.begin(), size_i);
out_strides_.insert(out_strides_.begin(), stride_i);
}
}
// Step (2)~(4) for the algorithm in NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
// on output geometry
auto out_maybe_overlap = _maybe_overlapping_memory(out_sizes_, out_strides_);
// For input geometry,
// check for size 0 dimensions,
// skip size 1 dimensions,
// Step (0)~(1) for the algorithm in NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
// on input geometry
// (stride-0 input dims are deliberately kept: they are a source of overlap.)
auto idim = input_geometry.dim();
IntArrayRef inp_sizes = input_geometry.sizes(), inp_strides = input_geometry.strides();
std::vector<int64_t> inp_sizes_, inp_strides_;
inp_sizes_.reserve(idim);
inp_strides_.reserve(idim);
for (int64_t i = idim - 1; i >= 0; i--) {
auto size_i = inp_sizes[i];
auto stride_i = inp_strides[i];
if (size_i == 0) {
return at::zeros(input_geometry.sizes(), grad.options());
} else if (size_i != 1) {
inp_sizes_.insert(inp_sizes_.begin(), size_i);
inp_strides_.insert(inp_strides_.begin(), stride_i);
}
}
// Step (2)~(4) for the algorithm in NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
// on input geometry
auto inp_maybe_overlap = _maybe_overlapping_memory(inp_sizes_, inp_strides_);
// Rest of this function implements
// Step (1)~(4) for the algorithm in NOTE [ as_strided Backward and layout-aware/agnostic autograd ]
// TODO: Raise if not all output values are visible in input geometry.
// Technically speaking, if you treat those values as constants, not
// raising is fine, and mathematically correct. However, these values
// really are contained in some base tensor, and by treating them as
// constants we are ignoring this tight dependency. Therefore, it is
// more sensible to raise here.
// Step (1): create underlying tensor as "storage"
// Re-base both offsets on the smaller of the two, so the buffer starts at the
// lowest offset that either geometry touches.
auto shared_offset = std::min(input_geometry.storage_offset(), storage_offset);
auto inp_effective_offset = input_geometry.storage_offset() - shared_offset;
auto out_effective_offset = storage_offset - shared_offset;
// Large enough to hold every element addressed by either geometry.
auto base_size = std::max(
_min_storage_size(inp_sizes_, inp_strides_, inp_effective_offset),
_min_storage_size(out_sizes_, out_strides_, out_effective_offset)
);
auto storage = at::zeros({base_size}, grad.options());
// prepare indices tensor if we will do index_add_ later
// (entry k holds storage index k; viewing it with a geometry yields the
// storage slot of every element of that geometry)
c10::optional<at::Tensor> flatten_full_indices;
if (inp_maybe_overlap || out_maybe_overlap) {
flatten_full_indices = at::arange(0, base_size, grad.options().dtype(at::kLong));
}
// Step (2): use output geometry to scatter gradients into storage
if (out_maybe_overlap) {
// Several grad elements may share a storage slot: accumulate with index_add_.
auto out_indices = flatten_full_indices->as_strided(out_sizes_, out_strides_, out_effective_offset);
storage.index_add_(0, out_indices.reshape(-1), grad.reshape(-1));
} else {
// assume that new tensors have 0 storage offset
storage.as_strided(out_sizes_, out_strides_, out_effective_offset).copy_(grad);
}
// Step (3): if input tensor has overlapping memory, divide scattered gradient
// at storage[i] by the number of times i shows up in input geometry
if (inp_maybe_overlap) {
auto count = at::zeros_like(storage, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
auto inp_indices = flatten_full_indices->as_strided(inp_sizes_, inp_strides_, inp_effective_offset).reshape(-1);
count.index_add_(0, inp_indices, at::ones({1}, grad.options()).expand_as(inp_indices));
// Slots with count == 0 are never read by the input-geometry view in step (4).
storage.div_(count); // this will give nan outside visible range
}
// Step (4): return as_strided view of the storage tensor with input geometry
// (uses the full input sizes/strides, including size-1 dims, so the result has
// exactly the input's shape)
return storage.as_strided(inp_sizes, inp_strides, inp_effective_offset);
}
std::tuple<Tensor, Tensor> atan2_backward(const Tensor& grad, const Tensor& self, const Tensor& other, std::array<bool, 2> output_mask) {
if (!grad.defined()) {
return std::tuple<Tensor, Tensor>{Tensor(), Tensor()};
}
auto recip = (self * self + other * other).reciprocal();
return std::tuple<Tensor,Tensor>{
output_mask[0] ? grad * other * recip : Tensor(),
output_mask[1] ? grad * -self * recip : Tensor() };
}
// TODO: Seriously consider writing the derivative formulas for
// each output separately; there is not all that much sharing
// of computation going on here.
//
// Double backward of PReLU.
// grad_grad_input / grad_grad_weight are the incoming gradients for the first
// backward's input- and weight-gradients, grad_out is the first backward's
// grad_output, and input_ / weight_ are the forward inputs.
// Returns (ggO, gI, gW): gradients with respect to grad_out, input and weight.
// Undefined incoming gradients are zero-filled. If grad_grad_input,
// grad_grad_weight and grad_out are all undefined, three undefined tensors are
// returned.
std::tuple<Tensor, Tensor, Tensor> prelu_double_backward(
const Tensor & grad_grad_input,
const Tensor & grad_grad_weight,
const Tensor & grad_out,
const Tensor & input_,
const Tensor & weight_) {
if (!(grad_grad_input.defined() || grad_grad_weight.defined() || grad_out.defined())) {
return std::tuple<Tensor, Tensor, Tensor>(Tensor(), Tensor(), Tensor());
}
auto input = input_.contiguous();
auto weight = weight_.contiguous();
// Zero-fill undefined grads (TODO: do this more efficiently)
auto ggI = grad_grad_input.defined() ? grad_grad_input.contiguous() : at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
auto ggW = grad_grad_weight.defined() ? grad_grad_weight.contiguous() : at::zeros_like(weight, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
auto gO = grad_out.defined() ? grad_out.contiguous() : at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
// 0/1 masks for the i > 0 and i <= 0 regions of the input.
// NOTE(review): positive_mask is cast to ggI's type but nonpositive_mask to
// ggW's; presumably the two always agree -- confirm.
auto positive_mask = (input > 0).type_as(ggI);
auto nonpositive_mask = (input <= 0).type_as(ggW);
// Explanation: Let input be i, weight be w, grad_output be gO.
// f(i, w) = i if i > 0
// = w * i if i <= 0
// gI = df/di * gO = gO if i > 0 gW = df/dw * gO = 0 if i > 0
// = gO * w if i <= 0 = gO * i if i <= 0
// The rest is taking derivatives of these wrt i, w, gO and summing/expanding properly.
if (weight.numel() == 1) {
// from PReLU.forward: num_parameters == 0 is used to indicate that a
// single weight is shared among all input channels.
// this is a little tricky because PReLU currently doesn't take a shape so the weight may be
// 1-d when the input is a scalar (and there isn't a good Parameter API for that anyway until Variable
// and tensor are merged). So, use weight and ggW as 0-dim in this case.
bool scalar_input_1d_weight = (positive_mask.dim() == 0 && weight.dim() == 1);
auto weight_maybe_squeeze = scalar_input_1d_weight ? weight.squeeze() : weight;
auto ggW_maybe_squeeze = scalar_input_1d_weight ? ggW.squeeze() : ggW;
// d(gI)/d(gO): 1 where i > 0, w where i <= 0.
auto mask = positive_mask + nonpositive_mask * weight_maybe_squeeze.expand_as(input);
// ggO = ggI * d(gI)/d(gO) + ggW * d(gW)/d(gO).
auto ggO = ggI * mask + ggW_maybe_squeeze.expand_as(gO) * (nonpositive_mask * input);
return std::tuple<Tensor, Tensor, Tensor>(
ggO,
ggW_maybe_squeeze.expand_as(gO) * gO * nonpositive_mask,
// A single shared weight: its gradient is a full reduction, broadcast back
// to weight's shape.
(ggI * gO * nonpositive_mask).sum().expand_as(weight)
);
} else {
// Expand ggW to match size of ggI; a simple expand doesn't work because
// ggW is the size of the input channel (dim==1 unless there is only 1 dimension). For example,
// let ggI be size (3,4,5,6,7) and ggW be size (4). Then we unsqueeze ggW to be size (4,1,1,1)
// so the expand succeeds.
auto dims_to_unsqueeze = std::max<int64_t>(input.dim() - 2, 0);
auto ggW_expanded = ggW;
for (int64_t i = 0; i < dims_to_unsqueeze; i++) {
ggW_expanded = ggW_expanded.unsqueeze(1);
}
ggW_expanded = ggW_expanded.expand_as(ggI);
auto gI = ggW_expanded * gO * nonpositive_mask;
// Reduce gW over every dim except the channel dim (dim 1, or dim 0 when the
// input is 1-d).
auto gW = ggI * gO * nonpositive_mask;
if (input.dim() > 1) {
gW = gW.sum(0);
}
while (gW.dim() > 1) {
gW = gW.sum(1);
}
// ggO is only computed when gO requires grad; otherwise it is returned undefined.
Tensor ggO;
if (gO.requires_grad()) {
// expand weight as input as in ggW/ggI above
auto weight_expanded = weight;
for (int64_t i = 0; i < dims_to_unsqueeze; i++) {
weight_expanded = weight_expanded.unsqueeze(1);
}
weight_expanded = weight_expanded.expand_as(input);
auto mask = positive_mask + nonpositive_mask * weight_expanded;
ggO = ggI * mask + ggW_expanded * nonpositive_mask * input;
}
return std::tuple<Tensor,Tensor,Tensor>{ggO, gI, gW};
}
}
// https://j-towns.github.io/papers/svd-derivative.pdf
//
// This makes no assumption on the signs of sigma.
//
// Backward of torch.svd. grads = {grad_u, grad_sigma, grad_v}, and any of them
// may be undefined. raw_u / sigma / raw_v are the forward outputs. Returns the
// gradient with respect to `self` as u_term + sigma_term + v_term; the term of
// an undefined incoming gradient is zero. Requires compute_uv=true.
Tensor svd_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
bool some, bool compute_uv, const Tensor& raw_u, const Tensor& sigma, const Tensor& raw_v) {
TORCH_CHECK(compute_uv,
"svd_backward: Setting compute_uv to false in torch.svd doesn't compute singular matrices, ",
"and hence we cannot compute backward. Please use torch.svd(compute_uv=True)");
auto m = self.size(-2);
auto n = self.size(-1);
auto k = sigma.size(-1);
auto gsigma = grads[1];
auto u = raw_u;
auto v = raw_v;
auto gu = grads[0];
auto gv = grads[2];
if (!some) {
// We ignore the free subspace here because possible base vectors cancel
// each other, e.g., both -v and +v are valid base for a dimension.
// Don't assume behavior of any particular implementation of svd.
// Keep only the first k columns of u, v and their gradients.
u = raw_u.narrow(-1, 0, k);
v = raw_v.narrow(-1, 0, k);
if (gu.defined()) {
gu = gu.narrow(-1, 0, k);
}
if (gv.defined()) {
gv = gv.narrow(-1, 0, k);
}
}
auto vt = v.transpose(-2, -1);
// Contribution of grad_sigma: u * diag(grad_sigma) * v^T.
Tensor sigma_term;
if (gsigma.defined()) {
sigma_term = at::matmul(u, at::matmul(gsigma.diag_embed(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1), vt));
} else {
sigma_term = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
}
// in case that there are no gu and gv, we can avoid the series of kernel
// calls below
if (!gv.defined() && !gu.defined()) {
return sigma_term;
}
auto ut = u.transpose(-2, -1);
auto im = at::eye(m, self.options());
auto in = at::eye(n, self.options());
auto sigma_mat = sigma.diag_embed(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1);
// diag(1 / sigma); a zero singular value makes this contain inf.
auto sigma_mat_inv = sigma.pow(-1).diag_embed(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1);
auto sigma_sq = sigma.pow(2);
// F[i][j] = sigma_j^2 - sigma_i^2 (a row broadcast against a column).
auto F = sigma_sq.unsqueeze(-2) - sigma_sq.unsqueeze(-1);
// The following two lines invert values of F, and fills the diagonal with 0s.
// Notice that F currently has 0s on diagonal. So we fill diagonal with +inf
// first to prevent nan from appearing in backward of this function.
// Afterwards F[i][j] = 1 / (sigma_j^2 - sigma_i^2) off the diagonal; repeated
// singular values still produce inf there.
F.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(INFINITY);
F = F.pow(-1);
Tensor u_term, v_term;
if (gu.defined()) {
u_term = at::matmul(u, at::matmul(F.mul(at::matmul(ut, gu) - at::matmul(gu.transpose(-2, -1), u)), sigma_mat));
if (m > k) {
// Extra term from the part of gu outside the column space of u
// ((I - u u^T) is that projection when u has orthonormal columns).
u_term = u_term + at::matmul(im - at::matmul(u, ut), at::matmul(gu, sigma_mat_inv));
}
u_term = at::matmul(u_term, vt);
} else {
u_term = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
}
if (gv.defined()) {
auto gvt = gv.transpose(-2, -1);
v_term = at::matmul(sigma_mat, at::matmul(F.mul(at::matmul(vt, gv) - at::matmul(gvt, v)), vt));
if (n > k) {
// Symmetric counterpart for gv, using (I - v v^T).
v_term = v_term + at::matmul(sigma_mat_inv, at::matmul(gvt, in - at::matmul(v, vt)));
}
v_term = at::matmul(u, v_term);
} else {
v_term = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
}
return u_term + sigma_term + v_term;
}
// "An extended collection of matrix derivative results for forward and reverse mode algorithmic differentiation"
// https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
// Backward of torch.eig. grads = {grad_lambda, grad_v}; lambda and v are the
// forward outputs. Column 0 of lambda is used as the eigenvalues. Any nonzero
// entry in column 1 is rejected; judging by the error message, that column
// holds imaginary parts, so only real eigenvalues are supported.
// Returns an undefined tensor if both incoming gradients are undefined.
// NOTE(review): `self` is not used in this function.
Tensor eig_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
bool eigenvectors, const Tensor& lambda, const Tensor& v) {
// This gradient only works for real eigenvalues at the moment.
TORCH_CHECK(eigenvectors,
"eig_backward: Setting eigenvectors to false in torch.eig doesn't compute eigenvectors ",
"and hence we cannot compute backward. Please use torch.eig(eigenvectors=True)");
auto zeros = at::zeros({1}, lambda.options());
TORCH_CHECK(
at::allclose(lambda.slice(/*dim=*/-1, /*start=*/1, /*end=*/2), zeros),
"eig_backward: Backward calculation does not support complex eigenvalues at the moment.");
auto glambda = grads[0];
auto gv = grads[1];
auto vt = v.transpose(-2, -1);
Tensor result;
// contribution from the eigenvectors
if (gv.defined()) {
auto rlambda = lambda.slice(/*dim=*/-1, /*start=*/0, /*end=*/1);
// hm[i][j] = 1 / (lambda_j - lambda_i) off the diagonal, 0 on the diagonal.
auto hm = rlambda.transpose(-2,-1) - rlambda;
hm.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(INFINITY);
hm.pow_(-1.0);
// Remove from each column of gv its component along the matching column of v
// (assumes v's columns are unit-norm -- TODO confirm).
auto gvortho = gv - at::sum(gv * v, /*dim=*/-2, /*keepdim=*/true) * v;
auto B = hm * at::matmul(vt, gvortho);
auto A = at::matmul(B, vt);
// at::solve(A, vt) returns X with vt * X = A, i.e. X = v^{-T} A.
std::tie(result, std::ignore) = at::solve(A, vt);
}
// contribution from eigenvalues
if (glambda.defined()) {
// Scale the rows of v^T by grad_lambda: diag(grad_lambda) * v^T.
auto grlambda = glambda.slice(/*dim=*/-1, /*start=*/0, /*end=*/1) * vt;
auto A = at::matmul(v, grlambda);
auto vvt = at::matmul(v, vt);
// Solve (v v^T) X = v diag(grad_lambda) v^T.
if (result.defined()) {
Tensor result1;
std::tie(result1, std::ignore) = at::solve(A, vvt);
result = result.add(result1);
}
else {
std::tie(result, std::ignore) = at::solve(A, vvt);
}
}
return result;
}
// http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf
//
// Backward of torch.symeig. grads = {grad_lambda, grad_v}; lambda and v are
// the forward outputs. Returns the symmetrized gradient with respect to `self`.
// NOTE(review): `upper` is not used in this function.
Tensor symeig_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
bool eigenvectors, bool upper, const Tensor& lambda, const Tensor& v) {
// This gradient is symmetric, and not triangular.
// symeig operates only on symmetric inputs, which is a subspace of
// R^{n x n}, and hence the derivative is not well-defined for off-diagonal
// elements. We resolve this by taking the gradient of the functionally independent
// elements of the matrix (i.e., the lower triangular portion of the input) and then
// reflect it on the upper triangular portion, thereby symmetrizing the gradient of
// the symeig operation. The motivation behind this choice is that symmetric gradient
// leads to stable gradient updates, and retains symmetry of the updated matrix if it
// were updated by a gradient based algorithm.
TORCH_CHECK(eigenvectors,
"symeig_backward: Setting eigenvectors to false in torch.symeig doesn't compute eigenvectors ",
"and hence we cannot compute backward. Please use torch.symeig(eigenvectors=True)");
auto glambda = grads[0];
auto gv = grads[1];
auto vt = v.transpose(-2, -1);
Tensor result;
if (gv.defined()) {
// F[i][j] = 1 / (lambda_j - lambda_i) off the diagonal and 0 on it, then
// multiplied elementwise by v^T gv; result = v (F o v^T gv) v^T.
Tensor F = lambda.unsqueeze(-2) - lambda.unsqueeze(-1);
F.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(INFINITY);
F.pow_(-1);
F.mul_(at::matmul(vt, gv));
result = at::matmul(v, at::matmul(F, vt));
} else {
result = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
}
if (glambda.defined()) {
// Eigenvalue term v diag(grad_lambda) v^T, accumulated in place.
result.add_(at::matmul(at::matmul(v, at::diag_embed(glambda, /*offset=*/0, /*dim1=*/-2, /*dim2=*/-1)), vt));
}
// Symmetrize: (result + result^T) / 2.
return result.add(result.transpose(-2, -1)).mul_(0.5);
}
// We refer to Walter, S.F and Lehmann, L., Algorithmic Differentiation of Linear
// Algebra Functions with Application in Optimum Experimental Design (Extended Version)
// The derivative for the QR decomposition is adapted from Eq. 42 of the
// above reference.
//
// grads = {grad_Q, grad_R} (either may be undefined); Q and R are the forward
// outputs. Returns the gradient with respect to `self` (A). Only square R is
// supported.
// NOTE(review): `some` is not used in this function.
Tensor qr_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
bool some, const Tensor& Q, const Tensor& R) {
auto grad_Q = grads[0];
auto grad_R = grads[1];
TORCH_CHECK(R.size(-2) == R.size(-1),
"The derivative when R is non-square is not implemented. ");
// Compute R (R')^{T}
Tensor R_term;
if (grad_R.defined()) {
R_term = at::matmul(R, grad_R.transpose(-2, -1));
} else {
// R is ... x N x N, grad_R is ... x N x N and grad_R.T is ... x N x N
R_term = at::zeros_like(R, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
}
// Compute Q^{T} Q'
Tensor Q_term;
if (grad_Q.defined()) {
Q_term = at::matmul(Q.transpose(-2, -1), grad_Q);
} else {
// Q is ... x M x N, Q.T is ... x N x M and grad_Q is ... x M x N
Q_term = at::zeros_like(R, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
}
// We want to compute: (rhs_solve_1 . R^{-T})
// Note that (rhs_solve_1 . R^{-T}) = (R^{-1} . rhs_solve_1^{T})^{T}
// Since R is upper triangular, we can do this using
// triangular_solve(rhs_solve_1^{T}, R)^{T}
// rhs_solve_1 keeps only the strictly lower triangle of the skew-symmetric sum.
auto rhs_solve_1 = R_term - R_term.transpose(-2, -1) + Q_term - Q_term.transpose(-2, -1);
rhs_solve_1 = at::tril(rhs_solve_1, /*k=*/-1);
Tensor solve_soln_1;
std::tie(solve_soln_1, std::ignore) = at::triangular_solve(rhs_solve_1.transpose(-2, -1), R,
/*upper=*/true, /*transpose=*/false,
/*unitriangular=*/false);
Tensor grad_A;
if (grad_R.defined()) {
grad_A = at::matmul(Q, solve_soln_1.transpose(-2, -1) + grad_R);
} else {
grad_A = at::matmul(Q, solve_soln_1.transpose(-2, -1));
}
// Successive computations involve computation of QQ^{T} which is identity when A is square
if (self.size(-1) != self.size(-2)) {
Tensor rhs_solve_2;
// We use the same trick from above for this computation
if (grad_Q.defined()) {
rhs_solve_2 = grad_Q - at::matmul(Q, Q_term);
} else {
rhs_solve_2 = -at::matmul(Q, Q_term);
}
Tensor solve_soln_2;
std::tie(solve_soln_2, std::ignore) = at::triangular_solve(rhs_solve_2.transpose(-2, -1), R,
/*upper=*/true, /*transpose=*/false,
/*unitriangular=*/false);
grad_A.add_(solve_soln_2.transpose(-2, -1));
}
return grad_A;
}
// Invertible case is derived from Jacobi's formula, and also can be found at:
// http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf
//
// Backward of det(self). For a single matrix (self.dim() == 2) the branch is
// chosen from det's value. For a batch, every matrix is routed to the singular
// (det == 0) or nonsingular path, and the per-path results are scattered back
// into one tensor.
Tensor det_backward(const Tensor & grad, const Tensor& self, const Tensor& det) {
// Singular matrices: |det| equals the product of the singular values, so the
// gradient goes through prod_backward on sigma and then svd_backward.
auto singular_case_backward = [&](const Tensor& grad, const Tensor& self, const Tensor& det) -> Tensor {
Tensor u, sigma, v;
std::tie(u, sigma, v) = self.svd();
auto gsigma = prod_backward(grad.unsqueeze(-1), sigma, det.unsqueeze(-1));
return svd_backward({{}, gsigma, {}}, self, true, true, u, sigma, v);
};
// Nonsingular matrices (Jacobi's formula): grad * det * A^{-T}.
auto nonsingular_case_backward = [&](const Tensor& grad, const Tensor& self, const Tensor& det) -> Tensor {
return unsqueeze_multiple(grad * det, {-1, -2}, self.dim()) * self.inverse().transpose(-2, -1);
};
if (self.dim() == 2) {
if (det.item<double>() == 0) {
return singular_case_backward(grad, self, det);
} else {
return nonsingular_case_backward(grad, self, det);
}
} else {
auto nonzero_det_indices = at::where(det);
if (nonzero_det_indices[0].size(0) == det.numel()) { // all determinants are nonzero (non-singular)
return nonsingular_case_backward(grad, self, det);
}
auto zero_det_indices = at::where(det == 0);
if (zero_det_indices[0].size(0) == det.numel()) { // all determinants are zero (singular)
return singular_case_backward(grad, self, det);
}
// Mixed batch: every entry lands in exactly one of the two index sets
// (nonzero vs. == 0), so the empty_like result is fully overwritten below.
Tensor grad_det = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
// invertible case
grad_det.index_put_(/*indices=*/nonzero_det_indices,
/*value=*/nonsingular_case_backward(grad.index(nonzero_det_indices),
self.index(nonzero_det_indices),
det.index(nonzero_det_indices)));
// non-invertible case, uses SVD
grad_det.index_put_(/*indices=*/zero_det_indices,
/*value=*/singular_case_backward(grad.index(zero_det_indices),
self.index(zero_det_indices),
det.index(zero_det_indices)));
return grad_det;
}
}
// Backward of logdet(self). Matrices with logdet == -inf (det == 0) take the
// SVD-based path; every other matrix, including ones whose logdet is NaN, takes
// the inverse-based path. Batches mixing both kinds are handled per entry and
// scattered back into one tensor.
Tensor logdet_backward(const Tensor & grad, const Tensor& self, const Tensor& logdet) {
auto singular_case_backward = [&](const Tensor& grad, const Tensor& self) -> Tensor {
Tensor u, sigma, v;
std::tie(u, sigma, v) = self.svd();
// logdet = \sum log(sigma)
auto gsigma = grad.unsqueeze(-1).div(sigma);
return svd_backward({{}, gsigma, {}}, self, true, true, u, sigma, v);
};
// d logdet(A) / dA = A^{-T}.
auto nonsingular_case_backward = [&](const Tensor& grad, const Tensor& self) -> Tensor {
return unsqueeze_multiple(grad, {-1, -2}, self.dim()) * self.inverse().transpose(-2, -1);
};
if (self.dim() == 2) {
if (logdet.item<double>() != -INFINITY) {
return nonsingular_case_backward(grad, self);
} else {
return singular_case_backward(grad, self);
}
} else {
auto finite_logdet_indices = at::where(logdet != -INFINITY);
if (finite_logdet_indices[0].size(0) == logdet.numel()) { // all log determinants are finite (non-singular)
return nonsingular_case_backward(grad, self);
}
auto neginf_logdet_indices = at::where(logdet == -INFINITY);
if (neginf_logdet_indices[0].size(0) == logdet.numel()) { // all log determinants are -inf (singular)
return singular_case_backward(grad, self);
}
// Mixed batch: the two index sets partition the batch, so every entry of
// the empty_like result is written below.
Tensor grad_logdet = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
// invertible case
grad_logdet.index_put_(/*indices=*/finite_logdet_indices,
/*value=*/nonsingular_case_backward(grad.index(finite_logdet_indices),
self.index(finite_logdet_indices)));
// non-invertible case, uses SVD
grad_logdet.index_put_(/*indices=*/neginf_logdet_indices,
/*value=*/singular_case_backward(grad.index(neginf_logdet_indices),
self.index(neginf_logdet_indices)));
return grad_logdet;
}
}
// Backward of slogdet(self) with respect to the logabsdet output only (no
// gradient flows through signdet). Matrices with signdet == 0 take the
// SVD-based path and the rest take the inverse-based path. Batches mixing both
// kinds are handled per entry and scattered back into one tensor.
Tensor slogdet_backward(const Tensor& grad_logabsdet,
const Tensor& self,
const Tensor& signdet, const Tensor& logabsdet) {
auto singular_case_backward = [&](const Tensor& grad_logabsdet, const Tensor& self) -> Tensor {
Tensor u, sigma, v;
std::tie(u, sigma, v) = self.svd();
// sigma has all non-negative entries (also with at least one zero entry)
// so logabsdet = \sum log(abs(sigma))
// but det = 0, so backward logabsdet = \sum log(sigma)
auto gsigma = grad_logabsdet.unsqueeze(-1).div(sigma);
return svd_backward({{}, gsigma, {}}, self, true, true, u, sigma, v);
};
// d log|det(A)| / dA = A^{-T}.
auto nonsingular_case_backward = [&](const Tensor& grad_logabsdet, const Tensor& self) -> Tensor {
return unsqueeze_multiple(grad_logabsdet, {-1, -2}, self.dim()) * self.inverse().transpose(-2, -1);
};
if (self.dim() == 2) {
if (signdet.item<double>() == 0) {
return singular_case_backward(grad_logabsdet, self);
} else {
return nonsingular_case_backward(grad_logabsdet, self);
}
} else {
auto nonzero_signdet_indices = at::where(signdet);
if (nonzero_signdet_indices[0].size(0) == logabsdet.numel()) { // all log determinants are finite (non-singular)
return nonsingular_case_backward(grad_logabsdet, self);
}
auto zero_signdet_indices = at::where(signdet == 0);
if (zero_signdet_indices[0].size(0) == logabsdet.numel()) { // all log determinants are -inf (singular)
return singular_case_backward(grad_logabsdet, self);
}
// Mixed batch: nonzero / zero signdet partition the batch, so every entry
// of the empty_like result is written below.
Tensor grad_slogdet = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
// invertible case
grad_slogdet.index_put_(/*indices=*/nonzero_signdet_indices,
/*value=*/nonsingular_case_backward(grad_logabsdet.index(nonzero_signdet_indices),
self.index(nonzero_signdet_indices)));
// non-invertible case, uses SVD
grad_slogdet.index_put_(/*indices=*/zero_signdet_indices,
/*value=*/singular_case_backward(grad_logabsdet.index(zero_signdet_indices),
self.index(zero_signdet_indices)));
return grad_slogdet;
}
}
// Reference:
// https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
// Sec. 2.3.1 Matrix inverse product
//
// Backward of triangular_solve(b, a) with forward solution x.
// grad_x is the incoming gradient for the solution. grad_m is the incoming
// gradient for the second output, which is presumably the coefficient-matrix
// copy (TODO confirm). Returns (grad_b, grad_a); both are undefined when
// grad_x and grad_m are both undefined, otherwise missing parts are zero-filled.
std::tuple<Tensor, Tensor> triangular_solve_backward(
const Tensor & grad_x, const Tensor & grad_m,
const Tensor & b, const Tensor & a, const Tensor & x,
const bool upper, const bool transpose, const bool unitriangular,
std::array<bool, 2> output_mask) {
Tensor grad_b, grad_a;
if (grad_x.defined() || grad_m.defined()) {
if (grad_x.defined()) {
// grad_b is computed whenever grad_x is defined (output_mask[0] is not
// consulted) because grad_a is built from it.
grad_b = std::get<0>(grad_x.triangular_solve(a, upper, !transpose, unitriangular));
if (output_mask[1]) {
grad_a = transpose ? -x.matmul(grad_b.transpose(-1, -2)) : -grad_b.matmul(x.transpose(-1, -2));
// Keep only a's triangle; with unitriangular the diagonal is excluded.
if (upper) {
grad_a = grad_a.triu((int) unitriangular);
} else {
grad_a = grad_a.tril(-((int) unitriangular));
}
}
}
if (!grad_a.defined()) {
grad_a = at::zeros({1}, a.options()).expand_as(a);
}
if (!grad_b.defined()) {
grad_b = at::zeros({1}, b.options()).expand_as(b);
}
if (output_mask[1] && grad_m.defined()) {
grad_a = grad_a.add(grad_m);
}
}
return std::tuple<Tensor, Tensor>{grad_b, grad_a};
}
std::tuple<Tensor, Tensor> cholesky_solve_backward(
const Tensor& grad_x, const Tensor& self,
const Tensor& input2, const Tensor& result, const bool upper) {
Tensor grad_self, grad_input2;
if (grad_x.defined()) {
grad_self = grad_x.cholesky_solve(input2, /*upper=*/upper);
Tensor common_term = at::matmul(grad_self, result.transpose(-2, -1));
common_term = common_term + common_term.transpose(-2, -1);
if (upper) {
grad_input2 = -at::matmul(input2, common_term);
} else {
grad_input2 = -at::matmul(common_term, input2);
}
}
return std::tuple<Tensor, Tensor>{grad_self, grad_input2};
}
// Generally speaking, fft's backward is ifft.
//
// Backward of _fft_with_size. Maps `grad` (shaped like the forward output)
// back to a gradient shaped like `self`. The forward mode (R2C / C2R / C2C) is
// read from complex_input / complex_output / onesided. The backward runs the
// transform in the opposite direction (`!inverse`) and fixes up the half of the
// spectrum that a onesided transform drops or implicitly duplicates. When
// `normalized` is false, the result is rescaled by the total number of signal
// elements.
// NOTE(review): `output_sizes` is not used in this function.
Tensor fft_backward(const Tensor& self, const Tensor& grad, int64_t signal_ndim,
                    bool complex_input, bool complex_output,
                    bool inverse, IntArrayRef checked_signal_sizes,
                    bool normalized, bool onesided,
                    IntArrayRef output_sizes) {
  Tensor gI;
  if (!complex_input && complex_output) {
    // Forward is R2C
    // Do inverse C2C and project onto real plane because grad can be
    // asymmetrical so C2R can't be used.
    if (onesided) {
      // Forward is R2C (onesided)
      // Think of onesided R2C rfft as
      //     1. view as complex numbers (fill complex dim with zeros)
      //     2. C2C fft
      //     3. discard half of results
      // So backward is
      //     1. fill the other half with zeros (with `zero_grad_shape` below)
      //        (C2C ifft only take twosided inputs so we need to fill here)
      //     2. inverse C2C ifft
      //     3. discard the complex dim
      int64_t zero_length = checked_signal_sizes[signal_ndim - 1] - grad.size(signal_ndim);
      auto complex_full_grad = grad;
      if (zero_length > 0) {
        std::vector<int64_t> zero_grad_shape(signal_ndim + 2);
        zero_grad_shape[0] = self.size(0);
        for (int64_t i = 1; i < signal_ndim; i++) {
          zero_grad_shape[i] = checked_signal_sizes[i - 1];
        }
        zero_grad_shape[signal_ndim] = zero_length;
        zero_grad_shape[signal_ndim + 1] = 2;
        complex_full_grad = at::cat({ grad, at::zeros(zero_grad_shape, grad.options()) }, signal_ndim);
      }
      gI = _fft_with_size(complex_full_grad, signal_ndim,
                          /* complex_input */ true, /* complex_output */ true,
                          !inverse, checked_signal_sizes, normalized,
                          /* onesided */ false, complex_full_grad.sizes()).select(-1, 0);
    } else {
      gI = _fft_with_size(grad, signal_ndim, /* complex_input */ true,
                          /* complex_output */ true, !inverse,
                          checked_signal_sizes, normalized,
                          /* onesided */ false, grad.sizes()).select(-1, 0);
    }
  } else if (complex_input && !complex_output && onesided) {
    // Forward is C2R (onesided)
    // Think of onesided C2R irfft as
    //    1. fill the other half by conjugate symmetry
    //    2. inverse C2C ifft
    //    3. discard the complex dimension
    // So backward is
    //    1. R2C rfft (essentially add dummy complex dimension, and dft)
    //    2. accumulate gradient by conjugate symmetry
    //       since rfft results follow conjugate symmetry, we only need to
    //       double some entries from onesided rfft results, i.e., the ones with
    //       their reflected indices also landing out of the onesided range. So
    //       consider the index of last dim:
    //           i.   idx = 0.
    //                Reflected to (N - 0) % N = 0. Not doubled.
    //           ii.  0 < idx < floor(N/2) (last).
    //                N > N - idx > ceil(N/2)
    //                Reflected to N - idx, which lies outside the onesided
    //                range [0, floor(N/2)]. Doubled.
    //           iii. idx = floor(N/2) = N/2 (last) when N even.
    //                Reflected to (N - N/2) % N = N/2. Not doubled.
    //           iv.  idx = floor(N/2) = (N-1)/2 (last) when N odd.
    //                Reflected to (N - (N-1)/2) % N = (N+1)/2. Doubled.
    //       Therefore, needs to double
    //           idx = 1, 2, ..., N/2 - 1     when N even
    //           idx = 1, 2, ..., (N-1)/2     when N odd
    //       that is
    //           idx = 1, 2, ..., N - (floor(N/2) + 1)
    //               = 1, 2, ..., N - onesided_length
    gI = _fft_with_size(grad, signal_ndim, /* complex_input */ false,
                        /* complex_output */ true, /* inverse */ false,
                        checked_signal_sizes, normalized, /* onesided */ true,
                        self.sizes());
    int64_t double_length = checked_signal_sizes[signal_ndim - 1] - self.size(signal_ndim);
    if (double_length > 0) {  // also covers case when signal size is zero
      gI.narrow(signal_ndim, 1, double_length).mul_(2);
    }
  } else {
    // Remaining cases: run the transform in the opposite direction with the
    // input/output complexity swapped.
    gI = _fft_with_size(grad, signal_ndim, complex_output, complex_input,
                        !inverse, checked_signal_sizes, normalized, onesided,
                        self.sizes());
  }
  if (normalized) {
    // If normalized, backward is exactly calling fft with inversed argument as
    // the forward because both are unitary.
    return gI;
  } else {
    // If not normalized, in backward, we need to upscale or downscale gI basing
    // on whether the forward is an inverse fft.
    // The seed must be int64_t: std::accumulate's running value has the type of
    // its init argument, so seeding with a plain `1` truncated the product to
    // `int` and overflowed once the signal had more than INT_MAX elements.
    const int64_t signal_numel = std::accumulate(checked_signal_sizes.begin(),
        checked_signal_sizes.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
    if (!inverse) {
      return gI.mul_(static_cast<double>(signal_numel));
    } else {
      return gI.div_(static_cast<double>(signal_numel));
    }
  }
}
// Helper for batchnorm_double_backward
// Sums `to_sum` over every dimension except dimension 1. With keepdim=true the
// reduced dimensions remain as size-1 dims. With keepdim=false only the
// original dim-1 extent survives.
Tensor sum_exclude_dim1(const Tensor& to_sum, bool keepdim=true) {
  Tensor reduced = to_sum.sum(0, keepdim);
  // Once dim 0 is dropped (keepdim=false), the original dim 1 moves to index 0,
  // so the index to preserve depends on keepdim.
  const int64_t preserved_dim = keepdim ? 1 : 0;
  int64_t d = reduced.dim() - 1;
  while (d > preserved_dim) {
    reduced = reduced.sum(d, keepdim);
    --d;
  }
  return reduced;
}
// Helper for batchnorm_double_backward
// Like expand_as_dim1, but without the final expand_as. Unsqueezes `src`
// (presumably a 1-d per-channel tensor) so that it lines up with dim 1 of
// `target`, as if the reductions that produced it had used keepdim=True.
// Ranks are compared as signed int64_t. The previous
// `target.sizes().size() - 1` was a size_t and wrapped to SIZE_MAX for a 0-dim
// target, so the unsqueeze loop's condition could never become false.
Tensor unsqueeze_dim1(const Tensor& src, const Tensor& target) {
  const int64_t target_rank_minus_1 = target.dim() - 1;
  auto src_expanded = src;
  // Insert trailing size-1 dims after the channel dim.
  while (src_expanded.dim() < target_rank_minus_1) {
    src_expanded = src_expanded.unsqueeze(1);
  }
  // Insert the leading size-1 batch dim.
  if (src_expanded.dim() == target_rank_minus_1) {
    src_expanded = src_expanded.unsqueeze(0);
  }
  return src_expanded;
}
// Helper for batchnorm_double_backward
// because gamma/ggG/ggB are 1-dimensional and represent dim==1, we can't
// do a straight expansion because it won't follow the broadcasting rules.
// `src` is unsqueezed with trailing size-1 dims until its rank is one less
// than target's, and then expanded to target's shape.
// Ranks are compared as signed int64_t. The previous
// `target.sizes().size() - 1` wrapped to SIZE_MAX for a 0-dim target.
Tensor expand_as_dim1(const Tensor& src, const Tensor& target) {
  const int64_t target_rank_minus_1 = target.dim() - 1;
  auto src_expanded = src;
  while (src_expanded.dim() < target_rank_minus_1) {
    src_expanded = src_expanded.unsqueeze(1);
  }
  return src_expanded.expand_as(target);
}
// Double backward of batch norm.
//
// ggI / ggG / ggB are the incoming gradients for the first backward's
// grad_input / grad_gamma / grad_beta outputs (any may be undefined). gO is the
// first backward's grad_output. Returns {gI, gG, ggO}: gradients with respect
// to `input`, `gamma` and `gO`. Outputs that receive no contribution are
// returned undefined.
//
// In training mode the saved batch statistics (save_mean, save_invstd) are
// used and differentiated through. Otherwise the running statistics
// (running_mean, running_var + eps) are used as constants.
std::tuple<Tensor, Tensor, Tensor> batchnorm_double_backward(
    const Tensor & input,
    const c10::optional<Tensor> & gamma,
    const Tensor & ggI,
    const Tensor & ggG,
    const Tensor & ggB,
    const Tensor & gO,
    const c10::optional<Tensor> & running_mean,
    const c10::optional<Tensor> & running_var,
    bool training,
    double eps,
    const c10::optional<Tensor> & save_mean,
    const c10::optional<Tensor> & save_invstd,
    std::array<bool,3> output_mask) {
  bool affine = isDefined(gamma);
  // TODO: Do we have a ScalarOrTensor type? Would such a thing exist?
  Tensor gamma_expanded;
  Tensor ggG_expanded, ggB_expanded;
  if (affine) {
    gamma_expanded = expand_as_dim1(*gamma, input);
    if (ggG.defined()) {
      ggG_expanded = expand_as_dim1(ggG, input);
    }
    if (ggB.defined()) {
      ggB_expanded = expand_as_dim1(ggB, input);
    }
  } else {
    // Without an affine weight, gamma behaves as the constant 1.
    gamma_expanded = at::ones({}, input.options());
  }

  // define some terms we will reuse
  // M: number of elements reduced per channel (every dim except dim 1).
  auto M = input.size(0);
  for (auto s : input.sizes().slice(2)) {
    M *= s;
  }
  // for half inputs, save_mean, save_invstd are float (ideally, we would cast
  // everything else, but not now)
  auto mu = unsqueeze_dim1(training ? toLegacyTensor(save_mean).to(input.scalar_type()) : toLegacyTensor(running_mean), input);
  auto input_sub_mu = input - mu;
  // (var + eps)^(-1/2), from save_invstd in training and from running_var
  // otherwise; the -1 and -3/2 powers below are derived from it.
  auto sigma2_eps_neg_1_2 = unsqueeze_dim1(
      training ? toLegacyTensor(save_invstd).to(input.scalar_type())
               : toLegacyTensor(running_var).add(Scalar(eps)).pow(-0.5),
      input);
  auto sigma2_eps_neg_1 = sigma2_eps_neg_1_2.pow(2);
  auto sigma2_eps_neg_3_2 = sigma2_eps_neg_1_2.pow(3);

  // calculate gI
  auto input_mu_sigma2_neg_3_2 = input_sub_mu * sigma2_eps_neg_3_2;
  auto gOinmu_sum = sum_exclude_dim1(gO * input_sub_mu);
  auto gO_sum = sum_exclude_dim1(gO);

  Tensor gI;
  if (ggI.defined() && training) {
    auto ggI_sum = sum_exclude_dim1(ggI);
    auto ggIinmu_sum = sum_exclude_dim1(ggI * input_sub_mu);
    auto all_sub = ((ggI_sum * gO_sum).div_(M)).sub_(sum_exclude_dim1(gO * ggI)).add_(
        (sigma2_eps_neg_1 * gOinmu_sum * ggIinmu_sum).mul_(3. / M));
    auto gI_0t = (input_mu_sigma2_neg_3_2 * all_sub).div_(M);
    auto gI_1t = (ggIinmu_sum * sigma2_eps_neg_3_2).div_(M) * (gO_sum.div(M) - gO);
    auto gI_2t = (gOinmu_sum * sigma2_eps_neg_3_2).div_(M) * (ggI_sum.div(M) - ggI);
    gI = gamma_expanded * (gI_0t.add_(gI_1t).add_(gI_2t));
  }

  // add contribution of gamma term to gI
  Tensor gI_G_term;
  if (affine && ggG.defined()) {
    if (training) {
      auto t0 = gO * sigma2_eps_neg_1_2;
      auto t1 = (sigma2_eps_neg_1_2 * gO_sum).div_(-M);
      // gOinmu_sum already equals sum_exclude_dim1(gO * input_sub_mu) and is
      // never modified in place above, so reuse it instead of repeating the
      // elementwise product and the reduction.
      auto t2 = (input_mu_sigma2_neg_3_2 * gOinmu_sum).div_(-M);
      gI_G_term = ggG_expanded * (t0.add_(t1).add_(t2));
      gI = gI.defined() ? gI.add_(gI_G_term) : gI_G_term;
    } else {
      gI_G_term = ggG_expanded * sigma2_eps_neg_1_2 * gO;
      gI = gI.defined() ? gI.add_(gI_G_term) : gI_G_term;
    }
  }

  // this is the first backward's grad_input
  auto first_back_grad_input = [&](const Tensor& gO, const Tensor& gamma) -> Tensor {
    auto h0 = (gamma * sigma2_eps_neg_1_2).div_(M);
    auto h1 = (M * gO).sub_(sum_exclude_dim1(gO)).sub_(
        input_sub_mu.mul(sigma2_eps_neg_1) * sum_exclude_dim1(gO * input_sub_mu));
    return h0 * h1;
  };

  // calculate gG
  Tensor gG;
  if (affine && ggI.defined()) {
    if (training) {
      // gG is just the first backwards with the gamma term removed (then shaped properly)
      gG = ggI * first_back_grad_input(gO, at::ones({}, sigma2_eps_neg_1_2.options()));
      gG = sum_exclude_dim1(gG, false);
    } else {
      gG = sum_exclude_dim1(ggI * gO * sigma2_eps_neg_1_2, false);
    }
  }

  // calculate ggO
  Tensor ggO;
  // contribution of input term
  if (ggI.defined()) {
    if (training) {
      ggO = first_back_grad_input(ggI, gamma_expanded);
    } else {
      ggO = ggI * sigma2_eps_neg_1_2 * gamma_expanded;
    }
  }
  if (ggG.defined()) {
    auto ggO_G_term = ggG_expanded * input_sub_mu * sigma2_eps_neg_1_2;
    ggO = ggO.defined() ? ggO.add_(ggO_G_term) : ggO_G_term;
  }
  if (ggB.defined()) {
    auto ggO_B_term = ggB_expanded;
    ggO = ggO.defined() ? ggO.add_(ggO_B_term) : ggO_B_term;
  }

  // A requested gamma gradient can only be missing when there is no gamma.
  if (output_mask[1] && !gG.defined()) {
    AT_ASSERTM(affine, "gamma should always be defined when it requires grad");
  }

  return std::tuple<Tensor, Tensor, Tensor>{gI, gG, ggO};
}
// Layer-norm backward written entirely with ATen ops (judging by the name, so
// that it can itself be differentiated again).
// X is viewed as an (M, N) matrix normalized over each row. mean / rstd are the
// forward's per-row statistics, viewed as (M, 1). dY, dmean and drstd are the
// incoming gradients for the forward's three outputs and may be undefined.
// Returns (dX, dgamma, dbeta). Each is computed only when requested by
// grad_input_mask and a gradient is available, and is undefined otherwise.
std::tuple<Tensor, Tensor, Tensor>
infinitely_differentiable_native_layer_norm_backward(
const Tensor& dY,
const Tensor& dmean,
const Tensor& drstd,
const Tensor& X,
const Tensor& mean,
const Tensor& rstd,
const c10::optional<Tensor>& gamma,
int64_t M,
int64_t N,
double eps,
std::array<bool, 3> grad_input_mask) {
Tensor dX;
Tensor dgamma;
Tensor dbeta;
const Tensor X_tensor = X.reshape({M, N});
const Tensor mean_tensor = mean.reshape({M, 1});
const Tensor rstd_tensor = rstd.reshape({M, 1});
const double s = 1.0 / static_cast<double>(N);
Tensor dY_tensor;
if (dY.defined()) {
dY_tensor = dY.reshape({M, N});
}
if (grad_input_mask[0]) {
Tensor gamma_tensor;
if (isDefined(gamma)) {
gamma_tensor = gamma->reshape({1, N});
}
Tensor rstd_cube = rstd_tensor * rstd_tensor * rstd_tensor;
Tensor var;
Tensor dvar;
if (drstd.defined()) {
// Recover var from rstd = 1 / sqrt(var + eps), clamped at 0.
var = ((rstd_tensor * rstd_tensor).reciprocal_() - eps).clamp_min(0);
// d(rstd)/d(var) = -0.5 * rstd^3: converts drstd into a variance gradient.
dvar = -0.5 * rstd_cube * drstd.view({M, 1});
}
Tensor ds;
Tensor db;
if (dY.defined()) {
// Per-row sums ds = sum(dY * gamma * X) and db = sum(dY * gamma); gamma is
// omitted when undefined.
ds = (isDefined(gamma) ? dY_tensor * X_tensor * gamma_tensor
: dY_tensor * X_tensor)
.sum(1)
.unsqueeze_(-1);
db = (isDefined(gamma) ? dY_tensor * gamma_tensor : dY_tensor)
.sum(1)
.unsqueeze_(-1);
// dX = a * dY * gamma + b * X + c, with s = 1 / N.
const Tensor& a = rstd_tensor;
const Tensor b = (db * mean_tensor - ds) * rstd_cube * s;
const Tensor c = -b * mean_tensor - db * rstd_tensor * s;
if (isDefined(gamma)) {
dX = a * dY_tensor * gamma_tensor + b * X_tensor + c;
} else {
dX = a * dY_tensor + b * X_tensor + c;
}
// Contribution of the mean / rstd outputs.
// NOTE(review): added only when BOTH dmean and drstd are defined; a lone
// dmean or drstd is ignored -- confirm this is intended.
if (dmean.defined() && drstd.defined()) {
dX += var_std_mean_backward(
{dvar, dmean.view({M, 1})},
X_tensor,
var,
mean_tensor,
{1},
false,
true,
false);
}
dX = dX.reshape_as(X);
} else if (dmean.defined() && drstd.defined()) {
dX = var_std_mean_backward(
{dvar, dmean.view({M, 1})},
X_tensor,
var,
mean_tensor,
{1},
false,
true,
false)
.reshape_as(X);
}
}
// dgamma = column sums of dY * (X - mean) * rstd.
if (grad_input_mask[1] && dY.defined()) {
dgamma = (dY_tensor * (X_tensor - mean_tensor) * rstd_tensor)
.sum(0)
.reshape_as(toLegacyTensor(gamma));
}
// NOTE(review): dbeta is reshaped with gamma's shape, so this path expects
// gamma to be defined whenever beta's gradient is requested.
if (grad_input_mask[2] && dY.defined()) {
dbeta = dY_tensor.sum(0).reshape_as(toLegacyTensor(gamma));
}
return std::make_tuple(dX, dgamma, dbeta);
}
// Group-norm backward written entirely with ATen ops (judging by the name, so
// that it can itself be differentiated again).
// X is viewed as (N, G, D, HxW) with D = C / G channels per group. mean / rstd
// are the forward's per-(sample, group) statistics, viewed as (N, G, 1, 1).
// dY, dmean and drstd are the incoming gradients and may be undefined.
// Returns (dX, dgamma, dbeta). Each is computed only when requested by
// grad_input_mask and a gradient is available, and is undefined otherwise.
std::tuple<Tensor, Tensor, Tensor>
infinitely_differentiable_native_group_norm_backward(
const Tensor& dY,
const Tensor& dmean,
const Tensor& drstd,
const Tensor& X,
const Tensor& mean,
const Tensor& rstd,
const c10::optional<Tensor>& gamma,
int64_t N,
int64_t C,
int64_t HxW,
int64_t group,
double eps,
std::array<bool, 3> grad_input_mask) {
const int64_t G = group;
const int64_t D = C / G;
// s = 1 / (number of elements normalized together in one group).
const double s = 1.0 / static_cast<double>(D * HxW);
Tensor dX;
Tensor dgamma;
Tensor dbeta;
const Tensor X_tensor = X.reshape({N, G, D, HxW});
const Tensor mean_tensor = mean.reshape({N, G, 1, 1});
const Tensor rstd_tensor = rstd.reshape({N, G, 1, 1});
Tensor dY_tensor;
Tensor ds;
Tensor db;
if (dY.defined()) {
// Per-channel spatial sums: ds = sum(dY * X), db = sum(dY), shape (N, G, D, 1).
dY_tensor = dY.reshape({N, G, D, HxW});
ds = (dY_tensor * X_tensor).sum(3).unsqueeze_(-1);
db = dY_tensor.sum(3).unsqueeze_(-1);
}
if (grad_input_mask[0]) {
Tensor gamma_tensor;
if (isDefined(gamma)) {
gamma_tensor = gamma->reshape({1, G, D, 1});
}
// Recover var from rstd = 1 / sqrt(var + eps), clamped at 0.
const Tensor var =
((rstd_tensor * rstd_tensor).reciprocal_() - eps).clamp_min(0);
const Tensor rstd_cube = rstd_tensor * rstd_tensor * rstd_tensor;
Tensor dvar;
if (drstd.defined()) {
// d(rstd)/d(var) = -0.5 * rstd^3.
dvar = -0.5 * rstd_cube * drstd.view({N, G, 1, 1});
}
if (dY.defined()) {
// dX = a * dY + b * X + c. b and c first hold the gamma-weighted sums of
// ds / db over the D channels of each group, then the final coefficients.
const Tensor a =
isDefined(gamma) ? rstd_tensor * gamma_tensor : rstd_tensor;
Tensor b = (isDefined(gamma) ? (ds * gamma_tensor).sum(2) : ds.sum(2))
.unsqueeze_(-2);
Tensor c = (isDefined(gamma) ? (db * gamma_tensor).sum(2) : db.sum(2))
.unsqueeze_(-2);
b = (c * mean_tensor - b) * rstd_cube * s;
c = -b * mean_tensor - c * rstd_tensor * s;
dX = a * dY_tensor + b * X_tensor + c;
// Contribution of the mean / rstd outputs.
// NOTE(review): added only when BOTH dmean and drstd are defined -- confirm.
if (dmean.defined() && drstd.defined()) {
dX += var_std_mean_backward(
{dvar, dmean.view({N, G, 1, 1})},
X_tensor,
var,
mean_tensor,
{2, 3},
false,
true,
false);
}
dX = dX.reshape_as(X);
} else if (dmean.defined() && drstd.defined()) {
dX = var_std_mean_backward(
{dvar, dmean.view({N, G, 1, 1})},
X_tensor,
var,
mean_tensor,
{2, 3},
false,
true,
false)
.reshape_as(X);
}
}
// dgamma = sum over the batch of (ds - db * mean) * rstd, per channel.
if (grad_input_mask[1] && dY.defined()) {
dgamma = ((ds - db * mean_tensor) * rstd_tensor).sum(0).reshape_as(toLegacyTensor(gamma));
}
// NOTE(review): dbeta is reshaped with gamma's shape, so gamma is expected
// to be defined here.
if (grad_input_mask[2] && dY.defined()) {
dbeta = db.sum(0).reshape_as(toLegacyTensor(gamma));
}
return std::make_tuple(dX, dgamma, dbeta);
}
std::tuple<Tensor, Tensor, Tensor> _trilinear_backward(const Tensor& grad_out, const Tensor& i1, const Tensor& i2, const Tensor& i3,
IntArrayRef expand1, IntArrayRef expand2, IntArrayRef expand3,
IntArrayRef sumdim, int64_t unroll_dim, std::array<bool, 3> grad_mask) {
Tensor grad_i1, grad_i2, grad_i3;
if (grad_out.defined()) {
if (grad_mask[0])
grad_i1 = at::_trilinear(grad_out, i2, i3, sumdim, expand2, expand3, expand1);
if (grad_mask[1])
grad_i2 = at::_trilinear(i1, grad_out, i3, expand1, sumdim, expand3, expand2);
if (grad_mask[2])
grad_i3 = at::_trilinear(i1, i2, grad_out, expand1, expand2, sumdim, expand3);
}
return std::tuple<Tensor, Tensor, Tensor>(grad_i1, grad_i2, grad_i3);
}
// Backward of log1p: d/dx log(1 + x) = 1 / (1 + x).
// Sparse inputs are rejected: the local gradient at an implicit zero is 1,
// which would turn a sparse gradient dense.
Tensor log1p_backward(const Tensor& grad, const Tensor& self) {
  if (self.is_sparse()) {
    AT_ERROR(
        "log1p of a sparse tensor is made to be non-differentiable since ",
        "local gradient of zero is 1 / (0 + 1) = 1 and it makes the tensor dense. ",
        "Use a different mathematical operation which preserves sparsity of gradients, ",
        "or report a bug if you think this is an error.");
  }
  const Tensor one_plus_self = self + 1;
  return grad / one_plus_self;
}
// Backward of a sparse-tensor constructor with respect to its `values`.
// The (densified) gradient is viewed with all sparse dims flattened into one,
// and the row for each entry of `indices` is gathered, giving a result shaped
// like the values.
Tensor sparse_constructor_values_backward(const Tensor& sparse_grad_out, const Tensor& indices, IntArrayRef values_shape) {
  // TODO: improve this backward by writing a kernel (maybe)
  const Tensor grad_dense = sparse_grad_out.is_sparse() ? sparse_grad_out.to_dense() : sparse_grad_out;
  const IntArrayRef dense_sizes = sparse_grad_out.sizes();
  // Leading dim = product of the sparse sizes; the rest match values_shape.
  std::vector<int64_t> gathered_shape = values_shape.vec();
  gathered_shape[0] = at::prod_intlist(dense_sizes.slice(0, indices.size(0)));
  const Tensor grad_flat = grad_dense.view(gathered_shape);
  const Tensor linear_indices = at::sparse::flatten_indices(indices, dense_sizes);
  return grad_flat.index_select(0, linear_indices);
}
// Because the backward of pad(input, pads) is just pad(grad_output, [-p for p in pads]):
// padding by the negated amounts trims exactly the border the forward added.
Tensor constant_pad_nd_backward(const Tensor& grad, IntArrayRef pad) {
  std::vector<int64_t> negated_pad;
  negated_pad.reserve(pad.size());
  for (const int64_t amount : pad) {
    negated_pad.push_back(-amount);
  }
  return at::constant_pad_nd(grad, negated_pad, 0);
}
// Double backward of the dense embedding: gathers the rows of `grad` selected
// by `indices` and lays them out as indices.sizes() + [-1], with the trailing
// dim inferred from grad.
Tensor embedding_dense_double_backward(const Tensor & grad, const Tensor & indices) {
  // The first backward already dealt with padding_idx and scaling by
  // frequency, so a plain row gather is all that is needed here.
  const auto flat_indices = indices.reshape(-1);
  const auto gathered_rows = grad.index_select(0, flat_indices);
  const auto index_shape = indices.sizes();
  std::vector<int64_t> out_shape(index_shape.begin(), index_shape.end());
  out_shape.emplace_back(-1);
  return gathered_rows.view(out_shape);
}
// Backward of advanced indexing: writes `grad` into `zeros_like_self` (by its
// name, a zero tensor shaped like self) at the positions selected by
// `indices`, and returns the updated tensor.
Tensor index_backward(Tensor zeros_like_self, TensorList indices, const Tensor& grad) {
  // accumulate=true: per index_put_ semantics, contributions to repeated
  // indices are summed instead of overwritten.
  constexpr bool accumulate = true;
  // Forwarded unchanged as the `unsafe` argument of _index_put_impl_.
  // NOTE(review): presumably this skips extra safety checks; confirm.
  constexpr bool unsafe = true;
  return at::_index_put_impl_(zeros_like_self, indices, grad, accumulate, unsafe);
}
// Backward of the cuDNN CTC loss.
// raw_grad (the gradient produced by the forward) is scaled by grad_out, which
// is unsqueezed at dims 0 and 2 so it broadcasts against raw_grad.
// NOTE(review): assumes grad_out/loss are per-batch [B] and raw_grad is
// [T, B, C]; confirm against the forward.
// With zero_infinity, positions whose loss is exactly 0 receive a zero
// gradient (presumably the forward reports zeroed infinite losses as 0;
// confirm).
Tensor _cudnn_ctc_loss_backward(const Tensor& grad_out, const Tensor& loss, const Tensor& raw_grad, bool zero_infinity) {
  auto scaled_grad = raw_grad * grad_out.unsqueeze(0).unsqueeze(2);
  if (!zero_infinity) {
    return scaled_grad;
  }
  // The fill value must be a 0-dim scalar tensor.  A shape-{0} tensor
  // (at::zeros({0}, ...)) is not broadcast-compatible with a trailing dim of
  // size > 1, so at::where would fail with a size mismatch.
  return at::where(
      loss.unsqueeze(0).unsqueeze(2) == 0,
      at::zeros({}, raw_grad.options()),
      scaled_grad);
}
bool any_variable_defined(variable_list& variables) {
for (auto variable : variables) {
if (variable.defined()) {
return true;
}
}
return false;
}
} // anonymous namespace
${autograd_function_definitions}
}}} // namespace torch::autograd::generated