#include <torch/csrc/autograd/FunctionsManual.h>
#include <torch/csrc/autograd/variable.h>
#include <ATen/ATen.h>
#include <ATen/Utils.h>
#include <c10/core/TensorOptions.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/ExpandUtils.h>
#include <ATen/core/Reduction.h>
#include <ATen/BatchedTensorImpl.h>
#include <ATen/Dispatch.h>
#include <ATen/ScalarOps.h>
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/IndexingUtils.h>
#include <ciso646>
#include <algorithm>
#include <numeric>
#include <functional>
// Helper functions for autogenerated code
// These used to be inlined into the codegened Functions.cpp
namespace torch {
namespace autograd {
namespace generated {
namespace details {
using at::Tensor;
using at::Scalar;
using at::IntArrayRef;
using at::TensorList;
bool isDefined(const c10::optional<Tensor>& t) {
return t.has_value() && t->defined();
}
bool isFwGradDefined(const c10::optional<Tensor>& t) {
return t.has_value() && t->defined() && t->fw_grad(/*level */ 0).defined();
}
Tensor toLegacyTensor(const c10::optional<Tensor>& t) {
return t.has_value() ? *t : Tensor();
}
Tensor toLegacyFwGrad(const c10::optional<Tensor>& t) {
return (t.has_value() && t->defined()) ? t->fw_grad(/*level */ 0) : Tensor();
}
Tensor toLegacyPrimal(const c10::optional<Tensor>& t) {
return (t.has_value() && t->defined()) ? t->_fw_primal(/*level */ 0) : Tensor();
}
void copy_range(variable_list& out, IndexRange range, const Tensor & t) {
AT_ASSERT(range.second <= out.size());
AT_ASSERTM(range.second - range.first == 1, "inconsistent range for Tensor output");
out[range.first] = t;
}
void copy_range(variable_list& out, IndexRange range, at::ArrayRef<Tensor> t) {
AT_ASSERT(range.second <= out.size());
AT_ASSERTM(range.second - range.first == t.size(), "inconsistent range for TensorList output");
std::copy(t.begin(), t.end(), out.begin() + range.first);
}
Tensor copysign_tensor_self_backward(const Tensor & grad, const Tensor & self, const Tensor & result) {
auto ratio = result / self;
ratio.masked_fill_(self == 0, 0);
return grad * ratio;
}
Tensor not_implemented(const char* name) {
throw std::runtime_error(
std::string("the derivative for '") + name + "' is not implemented");
}
Tensor maybe_multiply(const Tensor & t, const Scalar & s) {
bool is_one = false;
if (s.isFloatingPoint()) {
is_one = s.toDouble() == 1;
} else if(s.isIntegral(true)) {
is_one = s.toLong() == 1;
}
if (is_one) {
return t;
} else {
return t * s;
}
}
int64_t _safe_size(IntArrayRef sizes, IntArrayRef dim) {
int64_t size = 1;
if (sizes.size() == 0) {
return 1;
}
for (auto d : dim) {
d = at::maybe_wrap_dim(d, sizes.size());
size *= sizes[d];
}
return size;
}
static Tensor wrapped_scalar_tensor(Scalar scalar) {
auto tensor = scalar_to_tensor(scalar);
tensor.unsafeGetTensorImpl()->set_wrapped_number(true);
return tensor;
}
Tensor handle_r_to_c(ScalarType self_st, Tensor gradient_result) {
if (!at::isComplexType(self_st) && gradient_result.is_complex()) {
// R -> C
return at::real(gradient_result);
}
return gradient_result;
}
Tensor handle_r_to_c(Tensor self, Tensor gradient_result) {
if (!self.is_complex() && gradient_result.is_complex()) {
// R -> C
return at::real(gradient_result);
}
return gradient_result;
}
Tensor restore_reduced_dims(const Tensor &output, IntArrayRef dims, bool keepdim) {
if (keepdim) {
return output;
}
int64_t total_dims = output.dim() + dims.size();
std::vector<int64_t> target_shape(total_dims, 0);
for (int64_t i : dims) {
if (i < 0) {
i = total_dims + i;
}
target_shape[i] = 1;
}
int64_t j = 0;
for (int64_t i : output.sizes()) {
while (target_shape[j] > 0) j++;
target_shape[j++] = i;
}
return output.reshape(target_shape);
}
Tensor scale_grad_by_count(const Tensor &grad, const Tensor &mask, IntArrayRef dims) {
return (grad / mask.sum(dims, true)) * mask;
}
std::tuple<Tensor, Tensor> _euclidean_dist_backward(const Tensor & grad, const Tensor & x1, const Tensor & x2, const Tensor & res) {
if (!grad.defined()) {
return std::tuple<Tensor, Tensor>(Tensor(), Tensor());
}
// handle case at 0 where we return a subgradient containing 0
Tensor ratio = grad / res;
ratio.masked_fill_(res == 0, 0);
return std::tuple<Tensor, Tensor>{
x1 * ratio.sum(-1, true) - ratio.matmul(x2),
x2 * ratio.sum(-2, false).unsqueeze(-1) - ratio.transpose(-2, -1).matmul(x1)};
}
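// Gradient of the p-norm reduction. Roughly, for p outside of {0, 1, inf} the
// elementwise formula implemented below is
//   d||x||_p / dx_i = sgn(x_i) * |x_i|^(p - 1) / ||x||_p^(p - 1),
// with a subgradient of 0 substituted wherever the norm is 0.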
Tensor norm_backward(const Tensor& grad, const Tensor& self, const optional<Scalar> & p_, const Tensor& norm) {
return norm_backward(grad, self, p_, norm, {}, true);
}
Tensor norm_backward(Tensor grad, const Tensor& self, const optional<Scalar> & p_, Tensor norm, IntArrayRef dim, bool keepdim) {
size_t ndim = self.sizes().size();
double p = p_.value_or(2.0).toDouble();
Tensor self_scaled;
Tensor scale_v;
if (!keepdim && self.dim() != 0) {
grad = unsqueeze_multiple(grad, dim, ndim);
norm = unsqueeze_multiple(norm, dim, ndim);
}
if (p == 0.0) {
return at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
} else if (p == 1.0) {
return self.sgn() * grad;
} else if (p == 2.0) {
self_scaled = self;
scale_v = grad / norm;
} else if (std::isinf(p)) {
Tensor is_eq_max = (self.abs() == norm).logical_or_(self.isnan().logical_and_(norm.isnan())).type_as(self);
self_scaled = self.sign() * is_eq_max;
Tensor nb_max = is_eq_max.count_nonzero(dim);
if (self.dim() != 0) {
nb_max = unsqueeze_multiple(nb_max, dim, ndim);
}
scale_v = grad / nb_max;
} else if (p < 2.0) {
self_scaled = self.sgn() * self.abs().pow(p - 1);
scale_v = grad / norm.pow(p - 1);
} else {
self_scaled = self * self.abs().pow(p - 2);
scale_v = grad / norm.pow(p - 1);
}
// handle case at 0 where we return a subgradient containing 0
scale_v.masked_fill_(norm == 0, 0);
return self_scaled * scale_v;
}
Tensor pow_backward(Tensor grad, const Tensor & self, const Scalar & exponent) {
if (exponent.equal(0.0)) {
return at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
} else {
auto grad_lambda = [&](auto exp) { return grad * (exp * self.pow(exp - 1)).conj(); };
Tensor out = (exponent.isComplex()) ? grad_lambda(exponent.toComplexDouble()) : grad_lambda(exponent.toDouble());
return handle_r_to_c(self, out);
}
}
Tensor pow_backward_self(Tensor grad, const Tensor & self, const Tensor & exponent) {
auto out = at::where(exponent == 0.0, at::zeros({}, grad.options()), grad * (exponent * self.pow(exponent - 1)).conj());
return handle_r_to_c(self, out);
}
// Caveats:
// We define d(a^b)/db at a = 0 and b < 0 to be -inf. This is due to
// d(a^b)/db -> -inf for a fixed b as a -> +0
// Currently, tensorflow defines d(a^b)/db = nan for a = 0 and b < 0.
//
// We define d(a^b)/db = 0 for a = 0 and b = 0 by continuity as
// d(a^b)/db = 0 for a > 0 and b -> +0.
// Currently, tensorflow agrees with us.
Tensor pow_backward_exponent(Tensor grad, const Tensor& self, const Tensor& exponent, Tensor result) {
Tensor cond;
if (exponent.is_complex()) {
auto is_real_exp = at::logical_and(at::imag(exponent) == 0, at::real(exponent) >= 0);
cond = at::logical_and(self == 0, is_real_exp);
} else {
cond = at::logical_and(self == 0, exponent >= 0);
}
auto out = grad * at::where(cond,
at::zeros({}, grad.options()),
(result * self.log()).conj());
return handle_r_to_c(exponent, out);
}
Tensor pow_backward_exponent(Tensor grad, const Scalar & base, const Tensor& exponent, Tensor result) {
auto grad_lambda = [](Tensor a, Scalar b) { return (a * b.log()).conj(); };
if (base.equal(0.0)) {
auto cond = [](auto exp) {
if (exp.is_complex()) {
return at::logical_and(at::imag(exp) == 0, at::real(exp) >= 0);
} else {
return exp >=0;
}
};
auto out = grad * at::where(cond(exponent),
at::zeros({}, grad.options()),
grad_lambda(result, base));
return handle_r_to_c(exponent, out);
} else {
auto out = grad * grad_lambda(result, base);
return handle_r_to_c(exponent, out);
}
}
Tensor angle_backward(Tensor grad, const Tensor& self) {
if (self.is_complex()) {
return at::where(self == 0.0, at::zeros({}, self.options()),
grad * self / self.abs().pow(2) * Scalar(c10::complex<double>{0.0, 1.0}));
} else {
return at::zeros_like(self, at::MemoryFormat::Preserve);
}
}
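// A brief formula note: d/dx mvlgamma_p(x) = sum_{j=1..p} digamma(x + (1 - j)/2).
// The arange below enumerates the offsets (1 - j)/2, i.e. (1 - p)/2, ..., -1/2, 0.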
Tensor mvlgamma_backward(Tensor grad, const Tensor & self, int64_t p) {
Tensor args = at::arange(-p / 2. + 0.5, 0.5, 0.5, self.options());
args = args.add(self.unsqueeze(-1));
return grad * args.digamma_().sum(-1);
}
Tensor sgn_backward(Tensor result, Tensor grad, Tensor self) {
if (self.is_complex()) {
auto abs = at::abs(self);
// C -> C
// https://arxiv.org/pdf/1701.00392.pdf Section 4.20
return at::where(abs == 0.0, at::zeros({}, grad.options()), (grad/abs - (at::real(grad/self) * result)));
} else {
return at::zeros_like(self, at::MemoryFormat::Preserve);
}
}
Tensor mul_tensor_backward(Tensor grad, Tensor other, ScalarType self_st) {
auto out = grad * other.conj();
return handle_r_to_c(self_st, out);
}
Tensor div_tensor_self_backward(Tensor grad, Tensor other, ScalarType self_st) {
auto result = grad / other.conj();
return handle_r_to_c(self_st, result);
}
Tensor div_tensor_other_backward(Tensor grad, Tensor self, Tensor other) {
auto result = -grad * ((self / other) / other).conj();
return handle_r_to_c(other, result);
}
Tensor permute_backwards(const Tensor & grad, IntArrayRef fwd_dims) {
// invert the permutation
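// (illustration: fwd_dims = [2, 0, 1] yields dims = [1, 2, 0])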
auto ndims = fwd_dims.size();
std::vector<int64_t> dims(ndims);
for (size_t i = 0; i < ndims; i++) {
dims[at::maybe_wrap_dim(fwd_dims[i], ndims)] = i;
}
return grad.permute(dims);
}
Tensor rad2deg_backward(const Tensor& grad) {
constexpr double M_180_PI = 57.295779513082320876798154814105170332405472466564;
return at::mul(grad, wrapped_scalar_tensor(Scalar(M_180_PI)));
}
Tensor deg2rad_backward(const Tensor& grad) {
constexpr double M_PI_180 = 0.017453292519943295769236907684886127134428718885417;
return at::mul(grad, wrapped_scalar_tensor(Scalar(M_PI_180)));
}
Tensor unsqueeze_multiple(const Tensor & t, IntArrayRef dim, size_t n_dims) {
auto dims_to_unsqueeze = at::dim_list_to_bitset(dim, n_dims);
Tensor res = t;
for (size_t i = 0; i < n_dims; i++){
if (dims_to_unsqueeze[i]) {
res = res.unsqueeze(i);
}
}
return res;
}
Tensor sum_backward(const Tensor & grad, IntArrayRef sizes, IntArrayRef dims, bool keepdim) {
if (!keepdim && sizes.size() > 0) {
if (dims.size()==1) {
return grad.unsqueeze(dims[0]).expand(sizes);
} else {
Tensor res = unsqueeze_multiple(grad, dims, sizes.size());
return res.expand(sizes);
}
} else {
return grad.expand(sizes);
}
}
Tensor nansum_backward(const Tensor & grad, const Tensor & self, IntArrayRef dims, bool keepdim) {
auto sizes = self.sizes();
if (!keepdim && sizes.size() > 0) {
if (dims.size()==1) {
return grad.unsqueeze(dims[0]).expand(sizes) * self.isnan().logical_not();
} else {
Tensor res = unsqueeze_multiple(grad, dims, sizes.size());
return res.expand(sizes) * self.isnan().logical_not();
}
} else {
return grad.expand(sizes) * self.isnan().logical_not();
}
}
std::vector<int64_t> reverse_list(const IntArrayRef list) {
auto result = std::vector<int64_t>();
result.reserve(list.size());
for (auto iter = list.rbegin(); iter != list.rend(); iter++) {
result.push_back(*iter);
}
return result;
}
Tensor reverse_dim(const Tensor& t, int64_t dim) {
Tensor index = at::arange(t.size(dim) - 1, -1, -1, t.options().dtype(at::kLong));
return t.index_select(dim, index);
}
Tensor prod_safe_zeros_backward(const Tensor &grad, const Tensor& inp, int64_t dim) {
if (inp.size(dim) == 1) {
return grad;
}
auto ones_size = inp.sizes().vec();
ones_size[dim] = 1;
Tensor ones = at::ones(ones_size, grad.options());
Tensor exclusive_normal_nocp = at::cat({ones, inp.narrow(dim, 0, inp.size(dim) - 1)}, dim);
Tensor exclusive_normal = exclusive_normal_nocp.cumprod(dim);
Tensor narrow_reverse = reverse_dim(inp.narrow(dim, 1, inp.size(dim) - 1), dim);
Tensor exclusive_reverse_nocp = at::cat({ones, narrow_reverse}, dim);
Tensor exclusive_reverse = reverse_dim(exclusive_reverse_nocp.cumprod(dim), dim);
return grad * (exclusive_normal * exclusive_reverse);
}
// note that the gradient for prod is equivalent to:
// cumprod(exclusive, normal) * cumprod(exclusive, reverse), e.g.:
// input: [ a, b, c]
// cumprod(exclusive, normal): [1 , a, a * b]
// cumprod(exclusive, reverse): [b * c, c, 1]
// product: [b * c, a * c, a * b]
// and this is safe under input with 0s.
Tensor prod_backward(const Tensor& grad, const Tensor& input, const Tensor& result) {
if (input.dim() == 0) {
return grad;
}
Tensor zero_idx = (input == 0).nonzero();
if (zero_idx.numel() == 0) {
return (grad * result) / input;
} else if (zero_idx.size(0) > 1) {
return at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
} else {
return prod_safe_zeros_backward(grad, input.contiguous().view(-1), 0).view_as(input);
}
}
Tensor prod_backward(Tensor grad, const Tensor& input, Tensor result, int64_t dim, bool keepdim) {
if (input.dim() == 0) {
return grad;
}
dim = at::maybe_wrap_dim(dim, input.sizes().size());
if (!keepdim && input.dim() != 1) {
grad = grad.unsqueeze(dim);
result = result.unsqueeze(dim);
}
Tensor zero_mask = (input == 0);
Tensor slice_zero_count = zero_mask.sum(dim, true);
int64_t total_zeros = slice_zero_count.sum().item<int64_t>();
if (total_zeros == 0) {
return (grad * result) / input;
} else {
return prod_safe_zeros_backward(grad, input, dim);
}
}
Tensor solve_backward_self(const Tensor & grad, const Tensor & self, const Tensor & A) {
return at::linalg_solve(A.conj().transpose(-2, -1), grad);
}
Tensor solve_backward_A(const Tensor & grad, const Tensor & self, const Tensor & A, const Tensor & solution) {
Tensor grad_self = solve_backward_self(grad, self, A);
if (self.ndimension() == 2 && A.ndimension() == 2) {
return -at::mm(grad_self, solution.conj().transpose(-2, -1));
}
// if self was unsqueezed from (..., M) to (..., M, 1)
auto batched_rhs_shape = IntArrayRef(A.sizes().data(), A.dim()-1); // A.shape[:-1]
bool is_rhs_broadcasted = self.dim() == 1 || (A.dim()-1 == self.dim() && self.sizes().equals(batched_rhs_shape));
if (is_rhs_broadcasted) {
return -at::matmul(grad_self.unsqueeze(-1), solution.unsqueeze(-1).conj().transpose(-2, -1));
}
return -at::matmul(grad_self, solution.conj().transpose(-2, -1));
}
Tensor cumsum_backward(const Tensor & x, int64_t dim) {
// Need to check numel to see if there are no values (such as shape [0,2]), and dim to see if x is a scalar.
if (x.dim() == 0 || x.numel() == 0) {
return x;
}
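// The backward of cumsum is the reversed cumulative sum of the incoming
// gradient `x`. It is computed below without flip via the identity (a brief
// note): sum_{j >= i} x[j] = x[i] + (sum_j x[j]) - cumsum(x)[i].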
auto ret = at::cumsum(-x, dim);
auto ret_sum = ret.narrow(dim, ret.size(dim) - 1, 1).clone(at::MemoryFormat::Preserve);
ret -= ret_sum.expand(ret.sizes());
ret += x;
return ret;
}
Tensor logsumexp_backward(Tensor grad, const Tensor & self, Tensor result, IntArrayRef dim, bool keepdim) {
if (!keepdim && self.dim() != 0) {
grad = unsqueeze_multiple(grad, dim, self.sizes().size());
result = unsqueeze_multiple(result, dim, self.sizes().size());
}
return grad * (self - result).exp();
}
Tensor logcumsumexp_backward(Tensor grad, const Tensor & self, Tensor result, int64_t dim) {
if (grad.dim() == 0 || grad.numel() == 0) {
return grad;
}
// Reference: https://github.com/tensorflow/tensorflow/blob/
// 2a5910906a0e0f3dbc186ff9db6386d81a63448c/tensorflow/python/ops/math_grad.py#L1832-L1863
return AT_DISPATCH_FLOATING_TYPES(
at::typeMetaToScalarType(grad.dtype()),
"logcumsumexp_backward",
[grad, self, result, dim]() {
auto grad_min = at::empty_like(grad);
grad_min.fill_(std::numeric_limits<scalar_t>::lowest());
auto log_grad_positive = at::where(grad > 0, grad.log(), grad_min);
auto log_grad_negative = at::where(grad < 0, (-grad).log(), grad_min);
auto reverse_logcumsumexp = [dim](auto x) {
return at::flip(at::logcumsumexp(at::flip(x, {dim}), dim), {dim});
};
auto output_pos =
(reverse_logcumsumexp(log_grad_positive - result) + self).exp();
auto output_neg =
(reverse_logcumsumexp(log_grad_negative - result) + self).exp();
return output_pos - output_neg;
});
}
Tensor unbind_backward(const variable_list& grads, int64_t dim) {
IntArrayRef sizes;
at::TensorOptions o;
for (auto v : grads) {
if (v.defined()) {
sizes = v.sizes();
o = static_cast<Tensor>(v).options();
break;
}
}
auto grads_tensors = fmap(grads, [&](const Variable& v) {
return (
v.defined() ? static_cast<Tensor>(v) : at::zeros({}, o).expand(sizes));
});
return at::stack(grads_tensors, dim);
}
Tensor unsqueeze_to(const Tensor & self, IntArrayRef sizes) {
auto result = self;
int64_t nDims = sizes.size();
for (int64_t dim = 0; dim < nDims; dim++) {
if (sizes[dim] == 1) {
result = result.unsqueeze(dim);
}
}
return result;
}
Tensor unsqueeze_to(const Tensor & self, int64_t dim, IntArrayRef sizes) {
dim = at::maybe_wrap_dim(dim, sizes.size());
// in NumPy it's not an error to unsqueeze a scalar, but we still need to avoid
// unsqueezing it in the backward.
if (sizes.size() > 0 && sizes[dim] == 1) {
return self.unsqueeze(dim);
}
return self;
}
std::vector<Tensor> cat_tensors_backward(const Tensor & grad, const std::vector<std::vector<int64_t>> &sizes, int64_t dim) {
std::vector<Tensor> grad_inputs(sizes.size());
if (!grad.defined()) {
return grad_inputs;
}
dim = at::legacy_cat_wrap_dim(dim, sizes);
int64_t accumulate = 0;
for (size_t i = 0; i < sizes.size(); ++i) {
auto& shape = sizes[i];
// If input was empty tensor, gradInput should be empty tensor.
if (shape == std::vector<int64_t>({0})) {
grad_inputs[i] = at::zeros({0}, grad.options());
continue;
}
auto size = shape[dim];
accumulate += size;
grad_inputs[i] = grad.narrow(dim, accumulate - size, size);
}
return grad_inputs;
}
Tensor clamp_backward(const Tensor & grad, const Tensor &self, const optional<Scalar> & min, const optional<Scalar> & max) {
// clamp: gradients not defined on min and max, so we return the subgradient 1 for these cases.
if (max && min) {
return grad * ((self >= *min) * (self <= *max)).type_as(grad);
} else if (min) {
return grad * (self >= *min).type_as(grad);
} else if (max) {
return grad * (self <= *max).type_as(grad);
} else {
return grad;
}
}
// This function is used by load_derivatives.py to replace tensor.strides()
// calls that appear in derivative formulas. If the tensor has requires_grad
// set, this function returns its strides or throws an error if the tensor
// is sparse. If requires_grad is not set, an empty array is returned since
// there will be no backward pass.
//
// This function only supports the case where `input` is the tensor whose
// single derivative is being calculated.
//
// This function does not support `self` derivatives for inplace functions.
//
// Args:
// input Tensor to call .strides() on
// input_name Name of `input` tensor, from derivative formula
at::IntArrayRef strides_or_error(const Tensor & input, c10::string_view const & input_name) {
// TODO: Ideally, this function would never be called if requires_grad is
// not set. Once codegen is updated to avoid the call, we can remove this
// check.
if (input.requires_grad()) {
TORCH_CHECK(
!input.is_sparse(),
"The backward pass for this operation requires the '", input_name,
"' tensor to be strided, but a sparse tensor was given instead. ",
"Please either use a strided tensor or set requires_grad=False for '",
input_name, "'");
return input.strides();
} else {
return IntArrayRef({});
}
}
Tensor mm_mat1_backward(const Tensor & grad, const Tensor & mat2, at::IntArrayRef mat1_sizes, at::IntArrayRef mat1_strides, const Scalar & alpha) {
// if input was column-major, return grad as column-order for efficiency
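// Both branches compute grad @ mat2^H (scaled by alpha); the first simply
// produces the result in column-major order to match mat1's layout.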
if (mat1_strides[0] == 1 && mat1_strides[1] == mat1_sizes[0]) {
return maybe_multiply(mat2.conj().mm(grad.t()).t(), alpha);
} else {
return maybe_multiply(grad.mm(mat2.t().conj()), alpha);
}
}
Tensor mm_mat2_backward(const Tensor & grad, const Tensor & mat1, IntArrayRef sizes, IntArrayRef strides, const Scalar & alpha) {
// if input was column-major, return grad as column-order for efficiency
if (strides[0] == 1 && strides[1] == sizes[0]) {
if (mat1.is_sparse()) {
// Since mm(dense, sparse) doesn't exist,
// pass a transposed output matrix to the underlying "addmm"
// function directly.
int64_t out_rows = mat1.size(1);
int64_t out_cols = grad.size(1);
Tensor t = at::zeros({}, grad.options()).expand({out_rows, out_cols}, true);
Tensor r = at::empty({out_cols, out_rows}, grad.options()).t();
at::addmm_out(r, t, mat1.t(), grad, alpha, 1);
return r;
}
return maybe_multiply(grad.t().mm(mat1.conj()).t(), alpha);
} else {
return maybe_multiply(mat1.t().conj().mm(grad), alpha);
}
}
Tensor _sparse_addmm_sparse_backward(const Tensor& grad, const Tensor& sparse_, const Tensor& dense, const Scalar& alpha) {
AT_ASSERT(sparse_.is_sparse());
auto sparse = sparse_.coalesce();
Tensor grad_sparse = maybe_multiply(grad.mm(dense.t()), alpha);
return grad_sparse.sparse_mask(sparse);
}
// This function returns a new SparseTensor whose values are taken from Tensor `input` and filtered by the indices
// of `mask` (the values of `mask` are ignored). `input` and `mask` are sparse matrices, i.e. sparse tensors with
// sparse_dim=2 and dense_dim=2, and they must have the same shape.
// Note that the `output` must have the same `indices` as the `mask`, so we simply clone them.
// However, to get the `values` we have to use a specific helper function for CPU/CUDA that uses the `mask` data
// to filter the `values`. That's why we created the `_sparse_matrix_mask_helper` function.
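// A made-up illustration of the intended behavior (2 x 3 matrices, values shown
// as scalars for simplicity):
//   input:  indices [[0, 1], [0, 2]], values [a, b]
//   mask:   indices [[0, 1], [2, 2]], values (ignored)
//   output: indices [[0, 1], [2, 2]], values [0, b]
// i.e. the output reuses mask's indices and takes the matching values from
// input (0 where input has no stored element at that index).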
Tensor _sparse_matrix_mask(const Tensor& input, const Tensor& mask){
Tensor output = at::empty_like(mask);
Tensor mask_indices = mask._indices().clone();
Tensor r_values;
if (mask._nnz() == 0) {
r_values = at::zeros_like(mask._values());
} else {
r_values = _sparse_matrix_mask_helper(input, mask_indices.contiguous());
}
at::sparse::get_sparse_impl(output)->set_indices_and_values_unsafe(mask_indices, r_values);
return output;
}
Tensor sparse_sparse_matmul_backward(
const Tensor& grad,
const Tensor& a,
const Tensor& b,
int64_t grad_order) {
/*
To implement the backward algorithm for sparse matrix-matrix matmul (SPMM) we can start from the following definition
for dense tensors:
c = a @ b
then
a_grad = c_grad @ b^T
b_grad = a^T @ c_grad
So for sparse matrices we can use the following definition:
if grad_order == 0:
a_grad = sparse_matrix_mask(c_grad @ b^T, mask=a)
else:
b_grad = sparse_matrix_mask(a^T @ c_grad, mask=b)
*/
TORCH_CHECK(
grad_order == 0 || grad_order == 1,
": grad_order not in [0, 1] at sparse_sparse_matmul_backward function");
if (grad_order == 0) {
auto a_grad = _sparse_sparse_matmul(grad, b.t());
return _sparse_matrix_mask(a_grad.coalesce(), a.coalesce());
}
auto b_grad = _sparse_sparse_matmul(a.t(), grad);
return _sparse_matrix_mask(b_grad.coalesce(), b.coalesce());
}
Tensor renorm_backward(const Tensor & grad, const Tensor & self, Scalar p, int64_t dim, Scalar maxnorm) {
auto transposed_sizes = self.transpose(dim, 0).sizes().vec();
auto flatten = [&](const Tensor & t) {
return t.transpose(dim, 0).contiguous().view({t.size(dim), -1});
};
auto unflatten = [&](const Tensor & t) {
return t.contiguous().view(transposed_sizes).transpose(dim, 0);
};
// renorm computes the norm over all dimensions except `dim`, which is why
// we need the flatten and unflatten business. TODO: simplify this when we
// add support for norm over multiple dimensions.
auto self_flat = flatten(self);
auto grad_flat = flatten(grad);
auto norm_flat = self_flat.norm(p, 1, true);
auto grad_output = (self_flat * grad_flat).sum(1, true);
auto nb = norm_backward(grad_output, self_flat, p, norm_flat, 1, true);
auto invnorm = (norm_flat + 1e-7).reciprocal();
auto grad_norm = unflatten(maxnorm * invnorm * (grad_flat - invnorm * nb));
auto norm = unflatten(norm_flat.expand_as(self_flat));
// TODO: remove the detach once comparison ops no longer require grad
auto mask = Variable(norm < maxnorm).detach();
return at::where(mask, grad, grad_norm);
}
Tensor repeat_backward(Tensor grad, IntArrayRef repeats, IntArrayRef input_shape) {
auto find_iter = std::find(repeats.cbegin(), repeats.cend(), 0);
if (find_iter != repeats.cend()) {
return at::zeros(input_shape, grad.options());
}
const auto input_dims = input_shape.size();
int64_t num_unsqueezed = grad.dim() - input_dims;
for (int64_t i = 0; i < num_unsqueezed; ++i) {
grad = grad.sum(0, false);
}
at::DimVector grad_size, sum_dims;
for (size_t dim = 0; dim < input_dims; ++dim) {
int64_t repeat = repeats[dim + num_unsqueezed];
// Reshape gradient (repeat > 1)
// Index: [..., dim , ...] [..., dim , dim+1 , ...]
// Shape: From [..., dimsize, ...] to [..., repeat, dimsize/repeat, ...]
// The gradient tensor at 'dim' is reshaped to 'repeat' times of input tensor.
// Then, sum up gradients over repeated tensors along 'dim', and reduce shape
// from 'repeat * dimsize/repeat' to 'dimsize/repeat' ('input_dimsize').
// Example:
// Size(3, 2) Size(6, 2)
// [[v1_0, v1_1],
// [v1_2, v1_3],
// [[v0, v1], repeat(2, 1) [v1_4, v1_5],
// [v2, v3], -------------> [v2_0, v2_1],
// [v4, v5]] [v2_2, v2_3],
// [v2_4, v2_5]]
//
// input grad (3, 2) reshape (2, 3, 2) output grad (6, 2)
// [[[g1_0, g1_1], [[g1_0, g1_1],
// [g1_2, g1_3], [g1_2, g1_3],
// [[g1_0+g2_0, g1_1+g2_1], [g1_4, g1_5]], [g1_4, g1_5],
// [g1_0+g2_0, g1_1+g2_1], [g2_0, g2_1],
// [g1_0+g2_0, g1_1+g2_1]] [[g2_0, g2_1], [g2_2, g2_3],
// [g2_2, g2_3], [g2_4, g2_5]]
// [g2_4, g2_5]]]
// If the gradient tensor were instead reshaped to [..., dimsize/repeat, repeat, ...] and then
// summed over 'dim+1', the gradient for the input would not be correctly aligned with the input.
// Example:
// input grad (3, 2) reshape (3, 2, 2) output grad (6, 2)
// [[[g1_0, g1_1],
// [g1_2, g1_3]], [[g1_0, g1_1],
// [g1_2, g1_3],
// [[g1_0+g1_2, g1_1+g1_3], [[g1_4, g1_5], [g1_4, g1_5],
// [g1_4+g2_0, g1_5+g2_1], [g2_0, g2_1]], [g2_0, g2_1],
// [g2_2+g2_4, g2_3+g2_5]] [g2_2, g2_3],
// [[g2_2, g2_3], [g2_4, g2_5]]
// [g2_4, g2_5]]]
if (repeat != 1) {
grad_size.push_back(repeat);
sum_dims.push_back(grad_size.size() - 1);
}
// Don't need to reshape gradient into (repeat, input_shape[dim]) (repeat == 1)
grad_size.push_back(input_shape[dim]);
}
// One-time Reshape & Sum
// Reshape gradient to grad_size:
// 1. If repeat equals to 1, append input size at that dimension,
// 2. If repeat is larger than 1, append both repeat and input size at that dimension.
// Sum over all "repeat" dimensions from sum_dims:
// Example:
// Input Size (2, 3, 4, 5)
// repeat [4, 1, 9, 3]
// output/grad Size (8, 3, 36, 15)
// grad_size [4, 2, 3, 9, 4, 3, 5]
// sum_dims [0, 3, 5]
// When every dimension is repeated only once, sum_dims is empty; summing over it would reduce
// the whole grad tensor to a scalar rather than keep the original dimensions, hence the check below.
if (!sum_dims.empty()) {
grad = grad.reshape(grad_size);
grad = grad.sum(sum_dims);
}
return grad;
}
// p1m == 1 - p
Tensor _fused_dropout_backward(Tensor grad, Tensor mask, double p1m) {
if (grad.requires_grad()) {
// Use autograd-friendly backward if double backward is required
return grad * (mask.type_as(grad) * (1. / p1m));
} else {
return at::_masked_scale(grad, mask, 1. / p1m);
}
}
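// Distributes `grad` evenly among all entries of `input` that are equal to
// `value` (e.g. the result of max/min). A made-up illustration:
//   input = [1., 3., 3.], value = 3., grad = g  ->  grad_input = [0., g/2, g/2]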
Tensor evenly_distribute_backward(Tensor grad, const Tensor & input, const Tensor & value) {
if (input.is_cuda()) {
auto mask = (input == value).logical_or_(input.isnan().logical_and_(value.isnan()));
return mask * (grad / mask.sum());
} else {
auto mask = value.isnan().item<bool>() ? input.isnan() : input == value;
return grad.new_zeros(input.sizes(), input.options()).masked_fill_(mask, grad / mask.sum());
}
}
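// A brief formula note: with correction = 1 if `unbiased` else 0,
//   d Var(x) / dx_i = 2 * (x_i - mean(x)) / (N - correction),
// since the contributions through the mean cancel out.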
Tensor var_backward(const Tensor & grad, const Tensor & self, bool unbiased) {
return (2.0 / (self.numel() - unbiased)) * grad * (self - self.mean());
}
Tensor var_backward(Tensor grad, const Tensor & self, IntArrayRef dim, bool unbiased, bool keepdim) {
if (self.dim() == 0) {
return var_backward(grad, self, unbiased);
}
if (!keepdim && self.dim() > 1) {
grad = unsqueeze_multiple(grad, dim, self.sizes().size());
}
return (2.0 / (_safe_size(self.sizes(), dim) - unbiased)) * grad * (self - self.mean(dim, true));
}
Tensor std_backward(const Tensor & result, const Tensor & grad, const Tensor & self, bool unbiased) {
return var_backward((grad / (result * 2)).masked_fill_(result == 0, 0), self, unbiased);
}
Tensor std_backward(const Tensor & result, Tensor grad, const Tensor & self, IntArrayRef dim, bool unbiased, bool keepdim) {
return var_backward((grad / (result * 2)).masked_fill_(result == 0, 0), self, dim, unbiased, keepdim);
}
Tensor mean_backward(Tensor grad, const IntArrayRef sizes, IntArrayRef dim, bool keepdim) {
return sum_backward(grad, sizes, dim, keepdim) / _safe_size(sizes, dim);
}
Tensor mean_backward(Tensor grad, const IntArrayRef sizes, int numel) {
return grad.expand(sizes) / numel;
}
Tensor var_std_mean_backward(const variable_list& grads, const Tensor & self, const Tensor & r1, const Tensor & r2, IntArrayRef dim, bool unbiased, bool keepdim, bool is_std) {
Tensor grad;
if (grads[0].defined()) {
grad = is_std ? std_backward(r1, grads[0], self, dim, unbiased, keepdim) : var_backward(grads[0], self, dim, unbiased, keepdim);
}
if (grads[1].defined()) {
Tensor mean_grad = mean_backward(grads[1], self.sizes(), dim, keepdim);
grad = grads[0].defined() ? grad + mean_grad : mean_grad;
}
return grad;
}
Tensor var_std_mean_backward(const variable_list& grads, const Tensor & self, const Tensor & r1, const Tensor & r2, bool unbiased, bool is_std) {
Tensor grad;
if (grads[0].defined()) {
grad = is_std ? std_backward(r1, grads[0], self, unbiased) : var_backward(grads[0], self, unbiased);
}
if (grads[1].defined()) {
Tensor mean_grad = mean_backward(grads[1], self.sizes(), self.numel());
grad = grads[0].defined() ? grad + mean_grad : mean_grad;
}
return grad;
}
Tensor masked_scatter_backward(const Tensor & grad, const Tensor & mask, IntArrayRef sizes) {
int64_t numel = 1;
for (auto size : sizes) {
numel *= size;
}
auto mask_selected = grad.masked_select(mask);
auto diff_nelem = numel - mask_selected.numel();
if (diff_nelem > 0) {
// because masked_select returns a 1-d tensor with one element per true entry of the mask,
// we need to fill out the rest with zeros and then reshape back to `sizes`.
auto zeros_fillin = at::zeros({diff_nelem}, grad.options());
mask_selected = at::cat({mask_selected, zeros_fillin}, 0);
}
return mask_selected.view(sizes);
}
Tensor cholesky_backward(Tensor grad, bool upper, Tensor L) {
// cf. Iain Murray (2016); arXiv 1602.07527
// This gradient is symmetric, and not triangular.
// Cholesky additionally assumes that the input is symmetric, which is a subspace of
// R^{n x n}, and hence the derivative is not well-defined for off-diagonal
// elements. We resolve this by taking the gradient of the functionally independent
// elements of the matrix (i.e., the lower triangular portion of the input) and then
// reflect it on the upper triangular portion, thereby symmetrizing the gradient of
// the cholesky operation. The motivation behind this choice is that symmetric gradient
// leads to stable gradient updates, and retains symmetry of the updated matrix if it
// were updated by a gradient based algorithm.
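// In formulas: with Phi(A) = tril(A) with its diagonal halved, the code below
// effectively computes
//   grad_input = 0.5 * (C + C^H),  where C = L^{-H} Phi(L^H grad) L^{-1}.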
if (upper) {
L = L.transpose(-1, -2).conj();
grad = grad.transpose(-1, -2).conj();
}
auto L_inverse = std::get<0>(at::triangular_solve(at::eye(L.size(-1), L.options()), L, /*upper=*/false));
auto phi = at::matmul(L.transpose(-1, -2).conj(), grad);
phi.tril_().diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).mul_(0.5);
auto grad_input = at::matmul(at::matmul(L_inverse.transpose(-1, -2).conj(), phi), L_inverse);
return grad_input.add(grad_input.transpose(-1, -2).conj()).mul_(0.5); // Symmetrizing the gradient
}
Tensor cholesky_inverse_backward(Tensor grad, Tensor L, bool upper, Tensor inverse) {
Tensor grad_L;
if (grad.defined()) {
Tensor common_term = grad + grad.transpose(-2, -1);
common_term = at::matmul(inverse, at::matmul(common_term, inverse));
if (upper) {
grad_L = -at::matmul(L, common_term);
} else {
grad_L = -at::matmul(common_term, L);
}
} else {
grad_L = at::zeros({1}, L.options()).expand_as(L);
}
return grad_L;
}
Tensor split_with_sizes_backward(const std::vector<torch::autograd::Variable> &grads,
IntArrayRef split_sizes, int64_t dim, IntArrayRef sizes, const at::TensorOptions &options) {
dim = at::maybe_wrap_dim(dim, sizes.size());
// It's possible that some of the grads are not defined (they represent tensors of all 0s).
// Since at::cat can't handle those, materialize them as zeros.
std::vector<Tensor> grads_all_defined(grads.size());
for (size_t j = 0; j < grads.size(); ++j) {
if (grads[j].defined()) {
grads_all_defined[j] = grads[j];
} else {
auto length = split_sizes[j];
auto grad_size = sizes.vec();
grad_size[dim] = length;
grads_all_defined[j] = at::zeros(grad_size, options);
}
}
auto ret = at::cat(grads_all_defined, dim);
return ret;
}
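// The last chunk of a split may be smaller than split_size. A made-up
// illustration: dim_size = 10, split_size = 3 -> 4 chunks of sizes [3, 3, 3, 1],
// where 1 = 3 - (3 * 4 - 10) as computed below.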
Tensor split_backward(const std::vector<torch::autograd::Variable> &grads,
int64_t split_size, int64_t dim, IntArrayRef sizes, const at::TensorOptions &options) {
dim = at::maybe_wrap_dim(dim, sizes.size());
int64_t dim_size = sizes[dim];
int64_t num_splits = grads.size();
std::vector<int64_t> split_sizes(num_splits, split_size);
split_sizes[num_splits - 1] = split_size - (split_size * num_splits - dim_size);
return split_with_sizes_backward(grads, split_sizes, dim, sizes, options);
}
Tensor max_pool_double_backward(const Tensor & grad, const Tensor & indices, int dim) {
AT_ASSERT(indices.dim() >= dim);
auto size = indices.sizes().slice(0, indices.dim() - dim).vec();
size.push_back(-1);
auto indices_view = indices.view(size);
const auto memory_format = indices.suggest_memory_format();
return grad.contiguous(memory_format).view(size).gather(-1, indices_view).view(indices.sizes());
}
Tensor glu_double_backward(const Tensor & grad, const Tensor & grad_output, const Tensor & input, int64_t dim) {
auto& gO = grad_output;
auto input_size = input.size(dim) / 2;
auto first_half = input.narrow(dim, 0, input_size);
auto second_half = input.narrow(dim, input_size, input_size);
auto sig_second_half = second_half.sigmoid();
auto one_sub_sig_second_half = 1 - sig_second_half;
auto sig_one_sub_sig = sig_second_half * one_sub_sig_second_half;
auto ggI_first_half = grad.narrow(dim, 0, input_size);
auto ggI_second_half = grad.narrow(dim, input_size, input_size);
auto ggI_second_half_times_first_half = ggI_second_half * first_half;
auto gI_first_half = ggI_second_half * gO * sig_one_sub_sig;
auto second_order_sh = sig_one_sub_sig * one_sub_sig_second_half - sig_second_half * sig_one_sub_sig;
auto gI_second_half = ggI_second_half_times_first_half * gO * second_order_sh + ggI_first_half * gO * sig_one_sub_sig;
return at::cat({gI_first_half, gI_second_half}, dim);
}
Tensor glu_double_backward_grad_output(const Tensor & grad, const Tensor & input, int64_t dim) {
if (dim < 0) dim += input.dim();
auto sizes = input.sizes().vec();
sizes[dim] /= 2;
auto tmp = grad * glu_backward(at::ones(sizes, input.options()), input, dim);
return tmp.narrow(dim, 0, sizes[dim]) + tmp.narrow(dim, sizes[dim], sizes[dim]);
}
Tensor infinitely_differentiable_silu_backward(
const Tensor& grad_output,
const Tensor& input) {
const Tensor sigmoid = input.sigmoid();
return grad_output * sigmoid * (1.0 + input * (1.0 - sigmoid));
}
Tensor infinitely_differentiable_logit_backward(
const Tensor& grad,
const Tensor& self,
c10::optional<double> eps) {
if (eps) {
const double lo = eps.value();
const double hi = 1.0 - lo;
return at::where(
at::logical_and(self >= lo, self <= hi),
grad / (self * (1.0 - self)),
at::zeros({}, self.options()));
} else {
return at::where(
at::logical_and(self >= 0.0, self <= 1.0),
grad / (self * (1.0 - self)),
at::empty({}, self.options())
.fill_(std::numeric_limits<double>::quiet_NaN()));
}
}
Tensor kl_div_double_backward_grad_output(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction, bool log_target) {
auto result = kl_div_backward(grad, input, target, at::Reduction::None, log_target);
if (reduction == at::Reduction::Mean) {
return result.mean();
} else if (reduction == at::Reduction::Sum) {
return result.sum();
}
return result;
}
// Compute derivatives for targets.
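// A brief derivation note: with log_target == false the elementwise loss is
// target * (log(target) - self), so d/dtarget = log(target) + 1 - self; with
// log_target == true it is exp(target) * (target - self), so
// d/dtarget = exp(target) * (target - self + 1).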
Tensor kl_div_target_backward(Tensor grad_output, Tensor self, Tensor target, int64_t reduction, bool log_target) {
Tensor grad_target;
if (!log_target) {
grad_target = grad_output.mul(target.log().add_(1).sub_(self)).masked_fill_(target == 0, 0.);
}
else {
grad_target = grad_output.mul(target.add(1).sub_(self).mul_(target.exp()));
}
if (reduction == at::Reduction::Mean) {
grad_target.div_(target.numel());
}
return grad_target;
}
Tensor binary_cross_entropy_with_logits_target_backward(const Tensor& grad_output, const Tensor& self, const Tensor& target, const c10::optional<Tensor>& weight, const c10::optional<Tensor>& pos_weight, int64_t reduction) {
Tensor grad_target;
if (isDefined(pos_weight)) {
grad_target = (1. - self.sigmoid()).log_().sub_(pos_weight->mul(self.sigmoid().log_())).mul_(grad_output);
} else {
grad_target = self.mul(-grad_output);
}
if (isDefined(weight)) {
grad_target.mul_(*weight);
}
if (reduction == at::Reduction::Mean) {
grad_target.div_(target.numel());
}
return grad_target;
}
Tensor log_sigmoid_double_backward(const Tensor & grad, const Tensor & input) {
auto z = input.sigmoid();
return grad * (z - 1) * z;
}
Tensor softmax_double_backward(const Tensor & grad, const Tensor & grad_output, int dim, const Tensor & output) {
auto gO = grad_output;
auto ggI = grad;
auto ggI_output = ggI * output;
auto ggI_out_sum = ggI_output.sum(dim, true);
auto ggI_out_sum_output = ggI_out_sum * output;
auto gO_out_sum = (gO * output).sum(dim, true);
// gI calculation
auto gI_t0 = ggI_output * (gO - gO_out_sum);
auto gI_t1 = output * ((ggI_output * gO).sum(dim, true).sub_(gO_out_sum * ggI_out_sum));
auto gI_t2 = ggI_out_sum_output * gO;
auto gI_t3 = ggI_out_sum_output * gO_out_sum;
return gI_t0 - gI_t1 - gI_t2 + gI_t3;
}
Tensor log_softmax_double_backward(const Tensor & grad, const Tensor & grad_output, int dim, const Tensor & output) {
auto z = output.exp();
return z * grad_output.sum(dim, true) * ((grad * z).sum(dim, true) - grad);
}
// NOTE: [How to write vmap-compatible backward formulas]
//
// See NOTE: [vmap-incompatible in-place operations] for what it means for an
// in-place operation to be incompatible with vmap.
//
// If an in-place operation used in a backward formula is vmap-incompatible,
// then as developers we have the following options:
//
// - If the in-place operation directly followed the creation of a tensor with
// a factory function like at::zeros(...), we should replace the factory with a
// corresponding grad.new_zeros(...) call. The grad.new_zeros(...) call
// propagates the batch dims to the resulting tensor.
// For example:
// Before: at::zeros(input.sizes(), grad.options()).copy_(grad)
// After: grad.new_zeros(input.sizes()).copy_(grad)
//
// - If the in-place operation followed some sequence of operations, and
// we want to be able to vmap over the backward formula as-is (this is
// usually the case for simple (<15loc) backward formulas), then use
// inplaceIsVmapCompatible to guard the operation. For example:
// c = a * b
// Before: c.mul_(grad)
// After: c = at::inplaceIsVmapCompatible(c, grad) ? c.mul_(grad) : c * grad
//
// - If we don't want to vmap directly over the backward formula (e.g., if the
// backward formula is too complicated or has a lot of vmap-incompatible
// operations), then register the backward formula as an operator and eventually
// write a batching rule for it.
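// A brief derivation note for the function below: with
//   L = -(t * log(x) + (1 - t) * log(1 - x))   (eps omitted),
// dL/dx = (x - t) / (x * (1 - x)) and
// d2L/dx2 = (x^2 - 2*x*t + t) / (x^2 * (1 - x)^2), which is the `gI` term.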
Tensor binary_cross_entropy_double_backward(const Tensor & grad_output, const Tensor & grad, const Tensor & input, const Tensor & target, const c10::optional<Tensor>& weight, int64_t reduction) {
auto eps = 1e-12;
auto inp_pl_eps = input + eps;
auto one_m_inp_pl_eps = 1 - input + eps;
// gradient wrt input
auto gI = (input * input - 2 * input * target + target) / (inp_pl_eps.pow(2) * one_m_inp_pl_eps.pow(2));
if (at::inplaceIsVmapCompatible(gI, grad)) {
gI *= (grad * grad_output);
} else {
gI = gI * (grad * grad_output);
}
if (isDefined(weight)) {
gI *= *weight;
}
if (reduction == at::Reduction::Mean) {
return gI / input.numel();
} else if (reduction == at::Reduction::Sum) {
return gI.sum();
}
return gI;
}
Tensor binary_cross_entropy_double_backward_grad_output(const Tensor & grad, const Tensor & input, const Tensor & target, const c10::optional<Tensor>& weight, int64_t reduction) {
auto eps = 1e-12;
// gradient wrt grad_output
auto ggO = (input - target) / ((input + eps) * (1 - input + eps));
if (at::inplaceIsVmapCompatible(ggO, grad)) {
ggO *= grad;
} else {
ggO = ggO * grad;
}
if (isDefined(weight)) {
ggO *= *weight;
}
if (reduction == at::Reduction::Mean) {
return ggO / input.numel();
} else if (reduction == at::Reduction::Sum) {
return ggO.sum();
}
return ggO;
}
Tensor l1_loss_double_backward_grad_output(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction) {
auto output = l1_loss_backward(grad, input, target, at::Reduction::None);
if (reduction == at::Reduction::Mean) {
return output.mean();
} else if (reduction == at::Reduction::Sum) {
return output.sum();
}
return output;
}
Tensor smooth_l1_loss_double_backward(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction, double beta) {
// special case to protect against a divide-by-zero.
if (beta == 0) {
return at::zeros(grad.sizes(), grad.options());
}
auto d = (input - target).abs();
auto grad_input = grad * (d < beta).type_as(grad) / beta;
if (reduction == at::Reduction::Mean) {
grad_input /= input.numel();
}
return grad_input;
}
Tensor smooth_l1_loss_double_backward_grad_output(const Tensor & grad, const Tensor & grad_output, const Tensor & input, const Tensor & target, int64_t reduction, double beta) {
if (reduction == at::Reduction::None) {
return smooth_l1_loss_backward(grad, input, target, reduction, beta);
}
auto r = smooth_l1_loss_backward(ones_like(grad_output), input, target, reduction, beta);
return (r * grad).sum();
}
Tensor mse_loss_double_backward(const Tensor & grad, const Tensor & input, int64_t reduction) {
auto grad_input = 2 * grad;
if (reduction == at::Reduction::Mean) {
grad_input /= input.numel();
}
return grad_input;
}
Tensor mse_loss_double_backward_grad_output(const Tensor & grad, const Tensor & grad_output, const Tensor & input, const Tensor & target, int64_t reduction) {
if (reduction == at::Reduction::None) {
return mse_loss_backward(grad, input, target, reduction);
}
auto r = mse_loss_backward(ones_like(grad_output), input, target, reduction);
return (r * grad).sum();
}
Tensor soft_margin_loss_double_backward(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction) {
auto z = (input * -target).exp();
auto zplus1 = z + 1;
auto grad_input = grad * (target * target) * z / (zplus1 * zplus1);
if (reduction == at::Reduction::Mean) {
grad_input /= input.numel();
}
return grad_input;
}
Tensor soft_margin_loss_double_backward_grad_output(const Tensor & grad, const Tensor & grad_output, const Tensor & input, const Tensor & target, int64_t reduction) {
if (reduction == at::Reduction::None) {
return soft_margin_loss_backward(grad, input, target, reduction);
}
auto r = soft_margin_loss_backward(ones_like(grad_output), input, target, reduction);
return (r * grad).sum();
}
Tensor softplus_double_backward(const Tensor & grad, const Tensor & input, Scalar beta, Scalar threshold) {
auto x = (input * beta);
return sigmoid_backward(grad, x.sigmoid()) * (x < threshold).type_as(grad) * beta;
}
// NOTE [ as_strided Backward and layout-aware/agnostic autograd ]
//
// `storage_offset` is ignored for simplicity in this note. If you just want the
// full algorithm without explanation, scroll down to bottom of this note.
//
// Implementing the backward of as_strided is tricky because you have to deal
// with mappings that map one memory location to multiple indices, i.e., the
// output tensor has multiple indices pointing to **overlapping** memory
// addresses. This can happen in all sorts of weird cases. For example,
//
// x = torch.randn(15)
// x.as_strided([3, 3], [1, 0]) # "expand" case
// x.as_strided([3, 3], [2, 1]) # "size too large" case
// x.as_strided([3, 2], [3, 6]) # res[2, 0] points to 2*3 + 0*6 = 6
// # res[0, 1] points to 0*3 + 1*6 = 6
//
// Here is the general strategy we apply in implementing as_strided backward:
// 0. ??? (optimization step. we will talk about this later)
// 1. Create some underlying flattened tensor as if it is the base tensor
// representing the contiguous memory storage for both input and output.
// 2. Use the output geometry to scatter (or index_add) the gradients into
// this storage tensor.
// 3. ??? (fix for input tensor with overlapping memory. we will talk about
// this later)
// 4. Return the as_strided view of the storage tensor using input geometry.
//
// In step (2), if the output tensor doesn't have overlapping memory, we can
// safely scatter (`storage.as_strided(output_geometry).copy_(grad)`);
// otherwise, we must use `index_add` as gradients at different indices may need
// to be summed to a single location.
//
// For example, in this case:
//
// x = torch.randn(3)
// y = x.as_strided([3, 3], [1, 0]) # "expand" case
// # size [ 3, 3]
// # stride [ 1, 0]
// y.backward() # step (1): contiguous storage tensor `s` of size 3, which
// is large enough to be used as underlying storage
// for `x` and `y`.
// s = [ 0, 0, 0]
// # step (2): since `y` has overlapping memory, index_add grad
// into `s` based on `y`'s geometry, i.e.,
// s[i * y.stride(0) + j * y.stride(1)] += gy[i, j].
// s = [ 3, 3, 3]
// # step (4): as_strided view `s` using `x`'s geometry
// s = [ 3, 3, 3]
// grad_input = s.as_strided(x.size(), x.stride())
// = s.as_strided([3], [1])
// = [ 3, 3, 3]
//
// This is exactly what we would get if using `expand`. However, here the input
// tensor doesn't have overlapping memory. If it does, we must add an extra step
// before (4). Considering this case:
//
// t = torch.randn(3)
// x = t.expand(3, 3) # input with overlapping memory
// # size [3, 3]
// # stride [0, 1]
// y = x.as_strided([1], [1]) # contiguous output
// # size [1]
// # stride [1]
// y.backward() # step (1): contiguous storage tensor `s` of size 3, which
// is large enough to be used as underlying storage
// for `x` and `y`.
// s = [ 0, 0, 0]
// # step (2): scatter grad into `s` based on `y`'s geometry
// s = [ 1, 0, 0]
// # step (4): as_strided view `s` using `x`'s geometry
// s = [ 1, 0, 0]
// grad_input = s.as_strided([3, 3], [0, 1])
// = s.as_strided([3, 3], [0, 1])
// = [[ 1, 0, 0],
// [ 1, 0, 0],
// [ 1, 0, 0]]
// Is this result correct?
//
// `x.as_strided([1], [1])` call is obviously equivalent with
// `x[(0,) * x.dim()].view(1)` for any `x`. But autograd through the second
// gives gradient `[ [ 1, 0, 0], [ 0, 0, 0], [ 0, 0, 0]]`. For this specific
// case, indexing `x` at any index in first column is also equivalent, and
// yields a gradient of shape `[3 x 3]` containing eight 0's and one 1. There is
// an `x.size(1)`-times difference between these gradients computed from other
// PyTorch ops and the gradient we got from as_strided.
//
// You might conclude that the gradients from as_strided are wrong. However,
// let's first see why they are actually reasonable. Consider the pointwise
// perturbations by `delta` anywhere in the first column of `x`. It will lead to
// a `delta` change in the same memory location, and then `y` will change by
// `delta`. So one can say the gradient should be exactly 1 at the first column,
// as given by our above procedure.
//
// In the above computation of numerical gradients, they only match the
// analytical results because strides and memory locations are considered in the
// forward pass, i.e., this op (including both forward and backward) is
// layout-aware.
//
// However, in PyTorch, most (probably all) other ops (forward and backward) are
// layout-agnostic. E.g.,
//
// t = torch.randn(1)
// x = t.expand(2)
// y = x.sum()
// y.backward()
//
// Layout-agnostic autograd (as it is currently in PyTorch) will give you
//
// gy = 1
// gx = [ 1, 1] # SumBackward: torch.ones_like(x)
// gt = [ 2] # ExpandBackward: gx.sum()
//
// Note that `gx = [ 1, 1]`. However, if you perturb any value in `x` by `delta`
// (the other will also change by `delta`), `y` will change by `2 * delta`. So
// the gradients, if strides are taken into consideration, should be 2.
//
// Layout-aware autograd should give you
//
// gy = 1
// gx = [ 2, 2] # Because the backward considers the fact that the input `x`
// # is already expanded.
// gt = [ 2] # Layout-aware backward of expand is just a slicing because
// # the previous backward should have already taken care of
// # strides and made sure that gradients are the same along the
// # expanded dimension.
//
// As shown above, these two types are not compatible. Therefore, we must either
// make as_strided layout-agnostic, or make all other ops layout-aware.
//
// It is difficult to support layout-aware autograd (at least in the current
// codebase structure), because it would mean
// 1. storing tensor geometries of every input tensor for backward
// 2. depending on input geometry, the gradient computed from backward changes
// 3. ideally enforcing gradient of T to always have same strides as T
// (although these two methods only differ when it comes to overlapping memory)
//
// Therefore, we must formulate `as_strided` in a layout-agnostic way, i.e.,
// giving the same output regardless of the input layout. We consider
// `input.stride()` as a separate independent fixed argument `input_stride`.
// Then, `as_strided(input, size, stride)` can be thought of as:
// 1. "Scatter" each value of `input` into a "storage" using storage location
// computed from the value's index in `input`, `input.size()` and
// `input_stride`, but if N values end up in the same location, the value
// is average of those N values (they will be the same value anyways).
//
// Formal description:
// Denote the set of all input indices that pointing to the same storage
// location `storage[n]` as `S(n)`, i.e.,
//
// S(n) = { index : <index, input_stride> == n, index is valid given input.size() },
//
// where `<x, y>` is the dot product between `x` and `y`.
//
// Then, the process is:
//
// storage[n] = Avg { S(n) }
//
// Note that all values in `S(n)` are the same (they point to the same
// memory location anyways), so this step doesn't change anything, but
// effectively avoids having a dependency on the layout of `input`.
// I.e., the result holds fixed regardless of the layout of `input`, as
// long as `input_stride` is fixed.
//
// NOTE: for the forward pass, we can equivalently simply select any one of
// `S(n)` as `storage[n]`. However, considering this as an average
// operation makes backward easier (so all values in set
// `{ grad_input[i] : i in S(n) }` are the same, and it can use the
// same geometry as input).
// 2. As usual, return the as_strided view of `storage` using required output
// `size` and `stride`.
//
// To backward through this layout-agnostic version, we simply add the following
// step:
// .... (scatter gradients into the storage tensor using output geometry)
// 3. For all storage location n, `storage[n] /= |S(n)|`.
// .... (return as_strided view of the storage tensor using input geometry)
//
// Finally, we note that these general operations are expensive, so we apply the
// following optimizations:
// Add step (0): For all output dimension `d` with output stride 0, sum the
// gradients along dimension `d` (don't keepdim), and remove
// dimension `d` from output size and stride.
// (An optimization for "expand" cases so we may avoid step (3))
// Only apply step (3) when input tensor has overlapping memory.
//
// FULL ALGORITHM:
// 0. For all output dimension `d` with output stride 0, sum the gradients
// along dimension `d` (don't keepdim), and remove dimension `d` from
// output size and stride.
// 1. Create some underlying flattened tensor as if it is the base tensor
// representing the contiguous memory storage for both input and output.
// 2. Use the output geometry to scatter (or index_add) the gradients into
// this storage tensor `storage`.
// 3. If input tensor has overlapping memory,
// For all storage location `i`, `storage[i] /= N(i)`, where `N(i)` is the
// number of indices in input geometry pointing to the same storage
// location `i` (i.e., `|S(i)|` in equations above).
// 4. Return the as_strided view of the storage tensor using input geometry.
//
// See NOTE [ Detecting Memory Overlap Within A Strided Tensor ] on how to
// roughly detect overlapping memory.
// NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
//
// Checking memory overlap within a strided tensor is the special case of
// detecting memory overlap of two strided tensors, where the two tensors start
// at the same memory address. The latter is HARD (see #8212).
//
// But even this special case isn't simple. This note describes a check for an
// even more constrained simple case where we can be certain that there is no
// overlap.
//
// The checking algorithm can be described as:
// 0. Return [ pass check ] if any dimension has size 0
// 1. Ignore all dimensions that have size 1
// 2. If no remaining dimensions, return [ pass check ]
// 3. Sort the remaining dimensions according to the strides decreasingly
// 4. Check that for each dimension k,
//
// stride[k] > \sum_{ i > k } (size[i] - 1) * stride[i]
//
// That is equivalent to, after reordering the dimensions so strides are
// in decreasing order, checking that stride of each dimension is larger
// than the maximum memory offset in a slice at that dimension.
//
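// A made-up illustration of the check:
//   sizes [2, 3], strides [3, 1] (contiguous): sorted by stride the dims are
//     (stride 3, size 2), (stride 1, size 3); 3 > (3 - 1) * 1 = 2  -> pass.
//   sizes [3, 3], strides [1, 1]: 1 > (3 - 1) * 1 = 2 fails  -> overlap cannot
//     be ruled out (and indeed different indices alias the same memory).
//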
// Obviously this check passes for contiguous tensors ( the dimensions will be
// already sorted with LHS = stride[0] = \prod size[i] being exactly 1 larger
// than RHS ). Similarly, the check passes for tensors contiguous in all but
// the last dimension, and LHS = stride[0] = stride[-1] * \prod size[i] being
// exactly stride[-1] larger than RHS. (*)
//
// We will show that these view operations, including all our view operations
// *except for* general as_strided and unfold, also preserve this invariant:
//
// alias: Obviously preserves
//
// expand: All changed dimensions are removed in step (1)
//
// view: Consider the input dimensions as grouped into consecutive
// dimension "blocks", where dimensions are contiguous in each one.
// view only works when the output dimensions can also be
// grouped into the same consecutive blocks of same ordering.
//
// NB: this means that the number of elements and stride of the
// last dimension in each block is the same in input and
// output. (**)
//
// Notation:
// Consider a single such block B,
// ... B_prev[-1]], [ B[0], ..., B[i], ..., B[k] = B[-1] ], [ B_next[0], ...
// start--^^^^ ^^^^^^^^^^^^--end
// Each B[i] denotes a dimension index such that B[i] = B[0] + i.
//
// We first show that if a tensor (i.e., the input) satisfies the
// invariant, then after sorting, the dimensions within each block
// still remain consecutive. (***)
//
// After removing dimensions of size 1, the dimensions within a
// block are already sorted by strides in descending order. So
// sorting all dimensions will not change the relative ordering
// among them.
//
// Assume that some block B is not consecutive after sorting,
// i.e., there exists a dimension d between B[0] and B[-1] in
// sorted order.
//
// By (*), we know that
// stride[B[0]]
// = \sum_{i > 0} (size[B[i]] - 1) * stride[B[i]] + stride[B[-1]]
// < \sum_{i > 0} (size[B[i]] - 1) * stride[B[i]] + stride[d]
// <= \sum_{i > 0} (size[B[i]] - 1) * stride[B[i]] + (size[d] - 1) * stride[d]
// <= \sum{j > B[0]} (size[j] - 1) * stride[j],
//
// where the first < comes from sorting and
// the second <= comes from the fact that dimension d
// exists after step (1) and
// thus must have size greater
// than 1
// the third <= comes from the fact that each term in
// the sum is non-negative
//
// Then we have a contradiction as the invariant must not be
// satisfied at B[0]. So the original proposition is true.
//
// Now that we established the above claim (***), we consider the
// view operation as first sorting the dimensions (i.e., blocks),
// apply the original view (since it only cares about dimensions being
// consecutive and contiguous within each block), and then undo
// the sort.
//
// Consider a single block B in the output,
// ... ], [ B[0], ..., B[i], ..., B[k] = B[-1] ], [ ...
// start--^^^^ ^^^^^^^^^^^^--end
//
// By (*), we know that for all i
// stride[i] = stride[B[-1]] +
// \sum_{j=i+1}^{k} (size[B[j]] - 1) * stride[B[j]]
//
// Then the invariant is obviously satisfied at every dimension
// in this block if it is satisfied at dimension B[-1]. It only
// remains to show that it is satisfied at the last dimension in
// each block.
//
// Since the same blocks are present in both input and output
// with the same ordering, we will abuse the notation in the
// following statements.
//
// By (*), we know that the following holds for both input and
// output, for any block B:
// \sum_{i > B[-1]} (size[i] - 1) * stride[i]
// = \sum_{block B' after B} ((\prod_{j in B'} size[j]) - 1) * stride[B'[-1]]
// = \sum_{block B' after B} (numel(B') - 1) * stride[B'[-1]].
//
// By (**), numel(B') and stride[B'[-1]] are the same in input and
// output for each block B', so the right-hand side above remains
// the same in input and output. So both
// \sum_{i > B[-1]} (size[i] - 1) * stride[i]
// and
// stride[B[-1]]
// are the same in input and output.
//
// These two quantities are exactly the LHS and RHS of the
// invariant inequality. Since by assumption the invariant is
// satisfied in input at B[-1], it is also satisfied in output at
// B[-1]. This concludes the proof.
//
// squeeze: Special case of view
//
// unsqueeze: Special case of view
//
// slice: Consider slicing dimension i with step = k >= 1.
//
// Let stride' and size' be the output strides and sizes. We have
//
// stride'[i] = k * stride[i]
// size'[i] <= ceil(size[i] / k) = 1 + (size[i] - 1) / k (integer division)
//
// If size'[i] = 1, the invariant is obviously satisfied as we are
// just removing a dimension (after step (1)).
//
// Assume size'[i] > 1.
//
// By assumption, the invariant is satisfied at every dimension
// in input.
//
// For any dimension j, if stride[j] > stride[i], we have
// stride'[j] = stride[j]
// > (size[i] - 1) * stride[i]
// >= k * (size'[i] - 1) * stride[i] (since size'[i] - 1 <= (size[i] - 1) / k)
// = (size'[i] - 1) * stride'[i]
// >= stride'[i].
//
// If stride[j] < stride[i], we have
// stride'[j] = stride[j] < stride[i] <= stride'[i].
//
// So the sorting order remains unchanged after slice.
//
// Since
// (size'[i] - 1) * stride'[i]
// = k * (size'[i] - 1) * stride[i]
// <= (size[i] - 1) * stride[i],
// the term from this dimension i in the invariant inequality at
// other dimensions cannot increase after slice. So the
// invariant is preserved.
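//
// For example (illustrative): sizes [4, 6], strides [6, 1], slicing
// dimension 1 with step k = 2 gives size' = [4, 3], stride' = [6, 2].
// The invariant at dimension 0 becomes 6 > (3 - 1) * 2 = 4, which
// still holds (in the input it was 6 > (6 - 1) * 1 = 5).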
//
// narrow: Special case of slice
//
// select: narrow + squeeze
//
// permute: Sorting makes permutation of dimensions irrelevant
//
// transpose: Sorting makes swapping dimensions irrelevant
//
// diagonal: Effectively merging two dimensions i and j into a new
// dimension k s.t.
// stride'[k] = stride[i] + stride[j]
// size'[k] <= min(size[i], size[j]),
// where stride and size are on the input, and stride' and size'
// are on the output.
//
// Assume that size[i] > 1 and size[j] > 1. If either has size 1,
// then this is an unsqueeze on that dimension.
//
// WLOG, say stride[i] <= stride[j].
//
// Each dimension d in input with stride[d] > stride[j] has
// stride'[d] = stride[d]
// > (size[i] - 1) * stride[i] + (size[j] - 1) * stride[j]
// >= stride[i] + stride[j]
// = stride[k].
// So, considering the sorted dimensions, this is effectively
// removing i, and replacing j with k.
//
// For dimensions d with stride[i] < stride[d] < stride[j], the
// term from dimension i is removed in the invariant inequality.
// For dimensions d with stride[d] > stride[j], we have
// (size'[k] - 1) * stride'[k]
// <= (min(size[i], size[j]) - 1) * (stride[i] + stride[j])
// <= (size[i] - 1) * stride[i] + (size[j] - 1) * stride[j],
// so the term from i and j in the invariant can only decrease.
//
// The invariant at k itself follows from the invariant at j in the
// input: stride'[k] = stride[i] + stride[j] > stride[j], and (by the
// first chain above) the dimensions sorted below k in the output are
// a subset of those sorted below j in the input.
//
// So this generally relaxes the constraint, and thus the invariant
// is preserved.
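//
// For example (illustrative): sizes [2, 3, 3], strides [9, 3, 1],
// taking the diagonal of the last two dimensions gives sizes'
// [2, 3] and strides' [9, 4] (the merged dimension k has
// stride'[k] = 3 + 1 = 4 and size'[k] = 3). The invariant at the
// stride-9 dimension still holds: 9 > (3 - 1) * 4 = 8, just as
// 9 > (3 - 1) * 3 + (3 - 1) * 1 = 8 held in the input.
//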
// This implements steps (2)~(4) of the algorithm in
// NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
// Helper for as_strided_backward
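// For example (illustrative): sizes [2, 3] with strides [3, 1] pass the check
// (1 > 0 and 3 > (3 - 1) * 1 = 2), so no overlap is reported; sizes [2, 2]
// with strides [1, 1] fail it (the second stride 1 <= (2 - 1) * 1 = 1), so the
// memory may overlap.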
static inline bool _maybe_overlapping_memory(IntArrayRef sizes, IntArrayRef strides) {
if (sizes.size() > 0) {
std::vector<std::size_t> argsort(sizes.size());
std::iota(argsort.begin(), argsort.end(), 0);
std::sort(argsort.begin(), argsort.end(),
[&](std::size_t i, std::size_t j){ return strides[i] < strides[j]; });
int64_t max_index_in_slice = 0;
for (auto i : argsort) {
auto stride_ = strides[i];
if (stride_ <= max_index_in_slice) {
return true;
}
max_index_in_slice += stride_ * (sizes[i] - 1);
}
}
return false;
}
// Returns the minimum storage size needed to contain a tensor of sizes, strides, and storage_offset
// Helper for as_strided_backward
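// For example (illustrative): sizes [2, 3], strides [3, 1], storage_offset 2
// gives 2 + 1 + (2 - 1) * 3 + (3 - 1) * 1 = 8.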
static inline int64_t _min_storage_size(IntArrayRef sizes, IntArrayRef strides, int64_t storage_offset) {
int64_t storage_size = storage_offset + 1;
int64_t dim = sizes.size();
for (int64_t i = 0; i < dim; i++) {
auto size_i = sizes[i];
if (size_i == 0) {
return storage_offset;
}
storage_size += (size_i - 1) * strides[i];
}
return storage_size;
}
// See NOTE [ as_strided Backward and layout-aware/agnostic autograd ] for explanation
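// For example (illustrative): if the input is a contiguous 1-D tensor of size 4
// (storage_offset 0) and the forward call was input.as_strided([3], [1], 1),
// then neither geometry overlaps, so the backward creates a zero "storage"
// tensor of size 4, copies grad into storage.as_strided([3], [1], 1), and
// returns storage.as_strided([4], [1], 0), i.e., [0, grad[0], grad[1], grad[2]].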
Tensor as_strided_backward(Tensor grad, TensorGeometry input_geometry, IntArrayRef sizes, IntArrayRef strides, optional<int64_t> storage_offset_) {
// For output geometry,
// check for size 0 dimensions,
// skip size 1 dimensions,
// reduce grad on expanded dims (stride=0, size>1)
// Step (0) for the algorithm in NOTE [ as_strided Backward and layout-aware/agnostic autograd ]
// Step (0)~(1) for the algorithm in NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
// on output geometry
auto storage_offset = storage_offset_.value_or(input_geometry.storage_offset());
auto odim = grad.dim();
std::vector<int64_t> out_sizes_, out_strides_;
out_sizes_.reserve(odim);
out_strides_.reserve(odim);
for (int64_t i = odim - 1; i >= 0; i--) {
auto size_i = sizes[i];
auto stride_i = strides[i];
if (size_i == 0) {
return at::zeros(input_geometry.sizes(), grad.options());
} else if (size_i == 1) {
grad = grad.squeeze(i);
} else if (stride_i == 0) {
grad = grad.sum(i, false);
} else {
out_sizes_.insert(out_sizes_.begin(), size_i);
out_strides_.insert(out_strides_.begin(), stride_i);
}
}
// Step (2)~(4) for the algorithm in NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
// on output geometry
auto out_maybe_overlap = _maybe_overlapping_memory(out_sizes_, out_strides_);
// For input geometry,
// check for size 0 dimensions,
// skip size 1 dimensions,
// Step (0)~(1) for the algorithm in NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
// on input geometry
auto idim = input_geometry.dim();
IntArrayRef inp_sizes = input_geometry.sizes(), inp_strides = input_geometry.strides();
std::vector<int64_t> inp_sizes_, inp_strides_;
inp_sizes_.reserve(idim);
inp_strides_.reserve(idim);
for (int64_t i = idim - 1; i >= 0; i--) {
auto size_i = inp_sizes[i];
auto stride_i = inp_strides[i];
if (size_i == 0) {
return at::zeros(input_geometry.sizes(), grad.options());
} else if (size_i != 1) {
inp_sizes_.insert(inp_sizes_.begin(), size_i);
inp_strides_.insert(inp_strides_.begin(), stride_i);
}
}
// Step (1)~(4) for the algorithm in NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
// on input geometry
auto inp_maybe_overlap = _maybe_overlapping_memory(inp_sizes_, inp_strides_);
// Rest of this function implements
// Step (1)~(4) for the algorithm in NOTE [ as_strided Backward and layout-aware/agnostic autograd ]
// TODO: Raise if not all output values are visible in input geometry.
// Technically speaking, if you treat those values as constants, not
// raising is fine, and mathematically correct. However, these values
// really are contained in some base tensor, and by treating them as
// constants we are ignoring this tight dependency. Therefore, it is
// more sensible to raise here.
// Step (1): create underlying tensor as "storage"
auto shared_offset = std::min(input_geometry.storage_offset(), storage_offset);
auto inp_effective_offset = input_geometry.storage_offset() - shared_offset;
auto out_effective_offset = storage_offset - shared_offset;
auto base_size = std::max(
_min_storage_size(inp_sizes_, inp_strides_, inp_effective_offset),
_min_storage_size(out_sizes_, out_strides_, out_effective_offset)
);
auto storage = grad.new_zeros({base_size});
// prepare indices tensor if we will do index_add_ later
c10::optional<at::Tensor> flatten_full_indices;
if (inp_maybe_overlap || out_maybe_overlap) {
flatten_full_indices = at::arange(0, base_size, grad.options().dtype(at::kLong));
}
// Step (2): use output geometry to scatter gradients into storage
if (out_maybe_overlap) {
auto out_indices = flatten_full_indices->as_strided(out_sizes_, out_strides_, out_effective_offset);
storage.index_add_(0, out_indices.reshape(-1), grad.reshape(-1));
} else {
// assume that new tensors have 0 storage offset
storage.as_strided(out_sizes_, out_strides_, out_effective_offset).copy_(grad);
}
// Step (3): if input tensor has overlapping memory, divide scattered gradient
// at storage[i] by the number of times i shows up in input geometry
if (inp_maybe_overlap) {
auto count = at::zeros_like(storage, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
auto inp_indices = flatten_full_indices->as_strided(inp_sizes_, inp_strides_, inp_effective_offset).reshape(-1);
count.index_add_(0, inp_indices, at::ones({1}, grad.options()).expand_as(inp_indices));
storage.div_(count); // this will give nan outside visible range
}
// Step (4): return as_strided view of the storage tensor with input geometry
return storage.as_strided(inp_sizes, inp_strides, inp_effective_offset);
}
std::tuple<Tensor, Tensor> atan2_backward(const Tensor& grad, const Tensor& self, const Tensor& other, std::array<bool, 2> output_mask) {
if (!grad.defined()) {
return std::tuple<Tensor, Tensor>{Tensor(), Tensor()};
}
auto recip = (self * self + other * other).reciprocal();
return std::tuple<Tensor,Tensor>{
output_mask[0] ? grad * other * recip : Tensor(),
output_mask[1] ? grad * -self * recip : Tensor() };
}
// TODO: Seriously consider writing the derivative formulas for
// each output separately; there is not all that much sharing
// of computation going on here.
std::tuple<Tensor, Tensor, Tensor> prelu_double_backward(
const Tensor & grad_grad_input,
const Tensor & grad_grad_weight,
const Tensor & grad_out,
const Tensor & input_,
const Tensor & weight_) {
if (!(grad_grad_input.defined() || grad_grad_weight.defined() || grad_out.defined())) {
return std::tuple<Tensor, Tensor, Tensor>(Tensor(), Tensor(), Tensor());
}
auto input = input_.contiguous();
auto weight = weight_.contiguous();
// Zero-fill undefined grads (TODO: do this more efficiently)
auto ggI = grad_grad_input.defined() ? grad_grad_input.contiguous() : at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
auto ggW = grad_grad_weight.defined() ? grad_grad_weight.contiguous() : at::zeros_like(weight, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
auto gO = grad_out.defined() ? grad_out.contiguous() : at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
auto positive_mask = (input > 0).type_as(ggI);
auto nonpositive_mask = (input <= 0).type_as(ggW);
// Explanation: Let input be i, weight be w, grad_output be gO.
// f(i, w) = i if i > 0
// = w * i if i <= 0
// gI = df/di * gO = gO      if i > 0
//                 = gO * w  if i <= 0
// gW = df/dw * gO = 0       if i > 0
//                 = gO * i  if i <= 0
// The rest is taking derivatives of these wrt i, w, gO and summing/expanding properly.
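// In short, the three values returned below are (when defined), with 1[.]
// denoting the masks above:
//   ggO = ggI * (1[i > 0] + w * 1[i <= 0]) + ggW * i * 1[i <= 0]
//   gI' = ggW * gO * 1[i <= 0]                          (grad w.r.t. input)
//   gW' = ggI * gO * 1[i <= 0], summed to w's shape     (grad w.r.t. weight)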
if (weight.numel() == 1) {
// from PReLU.forward: num_parameters == 0 is used to indicate that a
// single weight is shared among all input channels.
// this is a little tricky because PReLU currently doesn't take a shape so the weight may be
// 1-d when the input is a scalar (and there isn't a good Parameter API for that anyway until Variable
// and tensor are merged). So, use weight and ggW as 0-dim in this case.
bool scalar_input_1d_weight = (positive_mask.dim() == 0 && weight.dim() == 1);
auto weight_maybe_squeeze = scalar_input_1d_weight ? weight.squeeze() : weight;
auto ggW_maybe_squeeze = scalar_input_1d_weight ? ggW.squeeze() : ggW;
auto mask = positive_mask + nonpositive_mask * weight_maybe_squeeze.expand_as(input);
auto ggO = ggI * mask + ggW_maybe_squeeze.expand_as(gO) * (nonpositive_mask * input);
return std::tuple<Tensor, Tensor, Tensor>(
ggO,
ggW_maybe_squeeze.expand_as(gO) * gO * nonpositive_mask,
(ggI * gO * nonpositive_mask).sum().expand_as(weight)
);
} else {
// Expand ggW to match size of ggI; a simple expand doesn't work because
// ggW is the size of the input channel (dim==1 unless there is only 1 dimension). For example,
// let ggI be size (3,4,5,6,7) and ggW be size (4). Then we unsqueeze ggW to be size (4,1,1,1)
// so the expand succeeds.
auto dims_to_unsqueeze = std::max<int64_t>(input.dim() - 2, 0);
auto ggW_expanded = ggW;
for (int64_t i = 0; i < dims_to_unsqueeze; i++) {
ggW_expanded = ggW_expanded.unsqueeze(1);
}
ggW_expanded = ggW_expanded.expand_as(ggI);
auto gI = ggW_expanded * gO * nonpositive_mask;
auto gW = ggI * gO * nonpositive_mask;
if (input.dim() > 1) {
gW = gW.sum(0);
}
while (gW.dim() > 1) {
gW = gW.sum(1);
}
Tensor ggO;
if (gO.requires_grad()) {
// expand weight as input as in ggW/ggI above
auto weight_expanded = weight;
for (int64_t i = 0; i < dims_to_unsqueeze; i++) {
weight_expanded = weight_expanded.unsqueeze(1);
}
weight_expanded = weight_expanded.expand_as(input);
auto mask = positive_mask + nonpositive_mask * weight_expanded;
ggO = ggI * mask + ggW_expanded * nonpositive_mask * input;
}
return std::tuple<Tensor,Tensor,Tensor>{ggO, gI, gW};
}
}
// https://j-towns.github.io/papers/svd-derivative.pdf
//
// This makes no assumption on the signs of sigma.
Tensor svd_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
bool some, bool compute_uv, const Tensor& raw_u, const Tensor& sigma, const Tensor& raw_v) {
TORCH_CHECK(compute_uv,
"svd_backward: Setting compute_uv to false in torch.svd doesn't compute singular matrices, ",
"and hence we cannot compute backward. Please use torch.svd(compute_uv=True)");
auto m = self.size(-2);
auto n = self.size(-1);
auto k = sigma.size(-1);
auto gsigma = grads[1];
auto u = raw_u;
// Currently torch.svd for complex dtypes returns the conjugate of V,
// while the backward formula is derived with just V (without the conjugation)
// therefore here we need to conjugate the V output of SVD and grads[2].
// Once https://github.com/pytorch/pytorch/issues/45821 is resolved
// extra .conj(), that are marked below in the code, shall be removed.
auto v = raw_v.conj(); // TODO: remove .conj()
auto gu = grads[0];
auto gv = grads[2].conj(); // TODO: remove .conj()
if (!some) {
// We ignore the free subspace here because possible base vectors cancel
// each other, e.g., both -v and +v are valid basis vectors for a dimension.
// Don't assume behavior of any particular implementation of svd.
u = raw_u.narrow(-1, 0, k);
v = raw_v.narrow(-1, 0, k).conj(); // TODO: remove .conj()
if (gu.defined()) {
gu = gu.narrow(-1, 0, k);
}
if (gv.defined()) {
gv = gv.narrow(-1, 0, k);
}
}
auto vh = v.conj().transpose(-2, -1);
Tensor sigma_term;
if (gsigma.defined()) {
gsigma = gsigma.to(self.dtype());
// computes u @ diag(gsigma) @ vh
sigma_term = at::matmul(u * gsigma.unsqueeze(-2), vh);
} else {
sigma_term = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
}
// in case that there are no gu and gv, we can avoid the series of kernel
// calls below
if (!gv.defined() && !gu.defined()) {
return sigma_term;
}
auto uh = u.conj().transpose(-2, -1);
auto im = at::eye(m, self.options());
auto in = at::eye(n, self.options());
auto sigma_mat = sigma.diag_embed(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).to(self.dtype());
auto sigma_mat_inv = sigma.pow(-1).diag_embed(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).to(self.dtype());
auto sigma_sq = sigma.pow(2);
auto F = sigma_sq.unsqueeze(-2) - sigma_sq.unsqueeze(-1);
// The following two lines invert the values of F and fill the diagonal with 0s.
// Notice that F currently has 0s on its diagonal, so we fill the diagonal with
// +inf first to prevent nan from appearing in the backward of this function.
F.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(INFINITY);
F = F.pow(-1);
Tensor u_term, v_term;
if (gu.defined()) {
auto guh = gu.conj().transpose(-2, -1);
u_term = at::matmul(u, at::matmul(F.mul(at::matmul(uh, gu) - at::matmul(guh, u)), sigma_mat));
if (m > k) {
u_term = u_term + at::matmul(im - at::matmul(u, uh), at::matmul(gu, sigma_mat_inv));
}
u_term = at::matmul(u_term, vh);
} else {
u_term = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
}
if (gv.defined()) {
auto gvh = gv.conj().transpose(-2, -1);
v_term = at::matmul(sigma_mat, at::matmul(F.mul(at::matmul(vh, gv) - at::matmul(gvh, v)), vh));
if (n > k) {
v_term = v_term + at::matmul(sigma_mat_inv, at::matmul(gvh, in - at::matmul(v, vh)));
}
v_term = at::matmul(u, v_term);
} else {
v_term = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
}
// for complex-valued input there is an additional term
// https://giggleliu.github.io/2019/04/02/einsumbp.html
// https://arxiv.org/abs/1909.02659
if (self.is_complex() && gu.defined()) {
// computes L = Identity.mul(uh @ gu)
Tensor L = at::matmul(uh, gu).diagonal(0, -2, -1).diag_embed(0, -2, -1);
L = L - L.conj().transpose(-2, -1);
Tensor imag_term = 0.5 * at::matmul(at::matmul(at::matmul(u, L), sigma_mat_inv), vh);
return u_term + sigma_term + v_term + imag_term;
}
return u_term + sigma_term + v_term;
}
// "An extended collection of matrix derivative results for forward and reverse mode algorithmic differentiation"
// https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
Tensor eig_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
bool eigenvectors, const Tensor& lambda, const Tensor& v) {
// This gradient only works for real eigenvalues at the moment.
TORCH_CHECK(eigenvectors,
"eig_backward: Setting eigenvectors to false in torch.eig doesn't compute eigenvectors ",
"and hence we cannot compute backward. Please use torch.eig(eigenvectors=True)");
auto zeros = at::zeros({1}, lambda.options());
TORCH_CHECK(
at::allclose(lambda.slice(/*dim=*/-1, /*start=*/1, /*end=*/2), zeros),
"eig_backward: Backward calculation does not support complex eigenvalues at the moment.");
auto glambda = grads[0];
auto gv = grads[1];
auto vt = v.transpose(-2, -1);
Tensor result;
// contribution from the eigenvectors
if (gv.defined()) {
auto rlambda = lambda.slice(/*dim=*/-1, /*start=*/0, /*end=*/1);
auto hm = rlambda.transpose(-2,-1) - rlambda;
hm.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(INFINITY);
hm.pow_(-1.0);
auto gvortho = gv - at::sum(gv * v, /*dim=*/-2, /*keepdim=*/true) * v;
auto B = hm * at::matmul(vt, gvortho);
auto A = at::matmul(B, vt);
std::tie(result, std::ignore) = at::solve(A, vt);
}
// contribution from eigenvalues
if (glambda.defined()) {
auto grlambda = glambda.slice(/*dim=*/-1, /*start=*/0, /*end=*/1) * vt;
auto A = at::matmul(v, grlambda);
auto vvt = at::matmul(v, vt);
if (result.defined()) {
Tensor result1;
std::tie(result1, std::ignore) = at::solve(A, vvt);
result = result.add(result1);
}
else {
std::tie(result, std::ignore) = at::solve(A, vvt);
}
}
return result;
}
// http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf
Tensor symeig_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
bool eigenvectors, bool upper, const Tensor& lambda, const Tensor& v) {
// This gradient is symmetric, and not triangular.
// symeig operates only on symmetric inputs, which is a subspace of
// R^{n x n}, and hence the derivative is not well-defined for off-diagonal
// elements. We resolve this by taking the gradient of the functionally independent
// elements of the matrix (i.e., the lower triangular portion of the input) and then
// reflect it on the upper triangular portion, thereby symmetrizing the gradient of
// the symeig operation. The motivation behind this choice is that symmetric gradient
// leads to stable gradient updates, and retains symmetry of the updated matrix if it
// were updated by a gradient based algorithm.
TORCH_CHECK(eigenvectors,
"symeig_backward: Setting eigenvectors to false in torch.symeig doesn't compute eigenvectors ",
"and hence we cannot compute backward. Please use torch.symeig(eigenvectors=True)");
auto glambda = grads[0];
auto gv = grads[1];
auto vh = v.conj().transpose(-2, -1);
Tensor result;
if (gv.defined()) {
Tensor F = lambda.unsqueeze(-2) - lambda.unsqueeze(-1);
F.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(INFINITY);
F.pow_(-1);
result = at::matmul(v, at::matmul(F * at::matmul(vh, gv), vh));
} else {
result = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
}
if (glambda.defined()) {
glambda = glambda.to(self.dtype());
// computes v @ diag(glambda) @ vh
Tensor glambda_term = at::matmul(v * glambda.unsqueeze(-2), vh);
if (at::inplaceIsVmapCompatible(result, glambda_term)) {
result.add_(glambda_term);
} else {
result = result + glambda_term;
}
}
return result.add(result.conj().transpose(-2, -1)).mul_(0.5);
}
Tensor linalg_qr_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
std::string mode, const Tensor& q, const Tensor& r){
bool compute_q, reduced;
std::tie(compute_q, reduced) = at::native::_parse_qr_mode(mode);
TORCH_CHECK(compute_q, "The derivative of qr is not implemented when mode='r'. "
"Please use torch.linalg.qr(..., mode='reduced')");
auto square_deep_case_backward = [](const Tensor& grad_Q,
const Tensor& grad_R,
const Tensor& A,
const Tensor& Q,
const Tensor& R) -> Tensor {
// For the square and deep (tall) case we refer to:
// Matthias Seeger, Asmus Hetzel, Zhenwen Dai, Eric Meissner, Neil D. Lawrence (2018). Auto-Differentiating Linear Algebra.
// https://arxiv.org/abs/1710.08717 Section 4.3 LQ Decomposition (Note that LQ decomposition is the transpose of QR decomposition)
// Hai-Jun Liao, Jin-Guo Liu, Lei Wang, Tao Xiang (2019). Differentiable Programming Tensor Networks.
// https://arxiv.org/abs/1903.09650 Section 3. QR factorization
// For derivations of complex-valued input case, see https://giggleliu.github.io/2019/04/02/einsumbp.html
// Compute R grad_R^H
Tensor R_term;
if (grad_R.defined()) {
R_term = at::matmul(R, grad_R.conj().transpose(-2, -1));
} else {
// R is ... x N x N, grad_R is ... x N x N and grad_R.T is ... x N x N
R_term = at::zeros_like(R, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
}
// Compute grad_Q^H Q
Tensor Q_term;
if (grad_Q.defined()) {
Q_term = at::matmul(grad_Q.conj().transpose(-2, -1), Q);
} else {
// Q is ... x M x N, Q.T is ... x N x M and grad_Q is ... x M x N
Q_term = at::zeros_like(R, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
}
Tensor M = R_term - Q_term;
// Symmetrize: M = tril(M) + tril(M).conj().transpose(-2, -1), then halve its diagonal
Tensor M_tril = at::tril(M);
M = M_tril + M_tril.conj().transpose(-2, -1);
M.diagonal(0, -2, -1).mul_(0.5);
Tensor rhs_term;
if (grad_Q.defined()) {
rhs_term = grad_Q + at::matmul(Q, M);
} else {
rhs_term = at::matmul(Q, M);
}
// We want to compute: (rhs_term @ R^{-H})
// Note that (rhs_term @ R^{-H}) = (R^{-1} @ rhs_term^H)^H
// Since R is upper triangular, we can do this using
// triangular_solve(rhs_term^H, R)^H
Tensor grad_A;
std::tie(grad_A, std::ignore) = at::triangular_solve(
rhs_term.conj().transpose(-2, -1),
R,
/*upper=*/true,
/*transpose=*/false,
/*unitriangular=*/false);
return grad_A.conj().transpose(-2, -1);
};
auto m = self.size(-2);
auto n = self.size(-1);
TORCH_CHECK(
((m <= n && (!reduced)) || reduced),
"The derivative of qr is not implemented when mode='complete' and nrows > ncols.");
auto grad_Q = grads[0];
auto grad_R = grads[1];
if (m >= n) {
return square_deep_case_backward(grad_Q, grad_R, self, q, r);
} else {
// For wide (m < n) input matrices A, partition A = [X|Y] and R = [U|V]
// X and U are square full rank matrices. We will partition grads,
// grad_R = [grad_U | grad_V] and grad_A = [grad_X | grad_Y].
// To obtain grad_X we reuse the gradient formula from the square case.
// Formulae: grad_X = square_case_grad(grad_Q_prime, grad_U, Q, U),
// where grad_Q_prime = grad_Q + Y @ grad_V^H
// and grad_Y = Q @ grad_V.
// Then concatenate grads to get grad_A = [grad_X | grad_Y].
auto Y = self.narrow(-1, m, n - m);
auto U = r.narrow(-1, 0, m);
Tensor grad_Y, grad_X, grad_V, grad_Q_prime;
if (grad_R.defined()) {
grad_V = grad_R.narrow(-1, m, n - m);
// reuse grad_R to store grad_U
grad_R = grad_R.narrow(-1, 0, m);
// grad_Q_prime starts with the value of Y @ grad_V^H
grad_Q_prime = at::matmul(Y, grad_V.conj().transpose(-2, -1));
} else {
// when grad_R is not defined then grad_V and grad_Q_prime
// get initialized with zeros
grad_V = at::zeros_like(Y, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
grad_Q_prime = at::zeros_like(q, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
}
if (grad_Q.defined()) {
// add the grad_Q term into grad_Q_prime when defined o/w is 0
grad_Q_prime = grad_Q_prime + grad_Q;
}
// Calculate grad_X using the helper. Grad_R contains the grad_U value
grad_X = square_deep_case_backward(grad_Q_prime, grad_R, self, q, U);
grad_Y = at::matmul(q, grad_V);
// Concatenate grad_X and grad_Y to get grad_A.
return at::cat({grad_X, grad_Y}, -1);
}
}
// Invertible case is derived from Jacobi's formula, and also can be found at:
// http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf
Tensor det_backward(const Tensor & grad, const Tensor& self, const Tensor& det) {
auto singular_case_backward = [&](const Tensor& grad, const Tensor& self, const Tensor& det) -> Tensor {
Tensor u, sigma, v;
std::tie(u, sigma, v) = self.svd();
auto gsigma = prod_backward(grad.unsqueeze(-1), sigma, det.unsqueeze(-1));
return svd_backward({{}, gsigma, {}}, self, true, true, u, sigma, v);
};
auto nonsingular_case_backward = [&](const Tensor& grad, const Tensor& self, const Tensor& det) -> Tensor {
return unsqueeze_multiple(grad * det, {-1, -2}, self.dim()) * self.inverse().transpose(-2, -1);
};
if (self.dim() == 2) {
if (det.item<double>() == 0) {
return singular_case_backward(grad, self, det);
} else {
return nonsingular_case_backward(grad, self, det);
}
} else {
auto nonzero_det_indices = at::native::toListOfOptionalTensors(at::where(det));
c10::optional<Tensor> first_nonzero_det_index = nonzero_det_indices[0];
if (first_nonzero_det_index->size(0) == det.numel()) { // all determinants are nonzero (non-singular)
return nonsingular_case_backward(grad, self, det);
}
auto zero_det_indices = at::native::toListOfOptionalTensors(at::where(det == 0));
c10::optional<Tensor> first_zero_det_index = zero_det_indices[0];
if (first_zero_det_index->size(0) == det.numel()) { // all determinants are zero (singular)
return singular_case_backward(grad, self, det);
}
Tensor grad_det = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
// invertible case
grad_det.index_put_(/*indices=*/nonzero_det_indices,
/*value=*/nonsingular_case_backward(grad.index(nonzero_det_indices),
self.index(nonzero_det_indices),
det.index(nonzero_det_indices)));
// non-invertible case, uses SVD
grad_det.index_put_(/*indices=*/zero_det_indices,
/*value=*/singular_case_backward(grad.index(zero_det_indices),
self.index(zero_det_indices),
det.index(zero_det_indices)));
return grad_det;
}
}
Tensor logdet_backward(const Tensor & grad, const Tensor& self, const Tensor& logdet) {
auto singular_case_backward = [&](const Tensor& grad, const Tensor& self) -> Tensor {
Tensor u, sigma, v;
std::tie(u, sigma, v) = self.svd();
// logdet = \sum log(sigma)
auto gsigma = grad.unsqueeze(-1).div(sigma);
return svd_backward({{}, gsigma, {}}, self, true, true, u, sigma, v);
};
auto nonsingular_case_backward = [&](const Tensor& grad, const Tensor& self) -> Tensor {
return unsqueeze_multiple(grad, {-1, -2}, self.dim()) * self.inverse().transpose(-2, -1);
};
if (self.dim() == 2) {
if (logdet.item<double>() != -INFINITY) {
return nonsingular_case_backward(grad, self);
} else {
return singular_case_backward(grad, self);
}
} else {
auto finite_logdet_indices = at::native::toListOfOptionalTensors(at::where(logdet != -INFINITY));
c10::optional<Tensor> first_finite_logdet_index = finite_logdet_indices[0];
if (first_finite_logdet_index->size(0) == logdet.numel()) { // all log determinants are finite (non-singular)
return nonsingular_case_backward(grad, self);
}
auto neginf_logdet_indices = at::native::toListOfOptionalTensors(at::where(logdet == -INFINITY));
c10::optional<Tensor> first_neginf_logdet_index = neginf_logdet_indices[0];
if (first_neginf_logdet_index->size(0) == logdet.numel()) { // all log determinants are -inf (singular)
return singular_case_backward(grad, self);
}
Tensor grad_logdet = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
// invertible case
grad_logdet.index_put_(/*indices=*/finite_logdet_indices,
/*value=*/nonsingular_case_backward(grad.index(finite_logdet_indices),
self.index(finite_logdet_indices)));
// non-invertible case, uses SVD
grad_logdet.index_put_(/*indices=*/neginf_logdet_indices,
/*value=*/singular_case_backward(grad.index(neginf_logdet_indices),
self.index(neginf_logdet_indices)));
return grad_logdet;
}
}
Tensor slogdet_backward(const Tensor& grad_logabsdet,
const Tensor& self,
const Tensor& signdet, const Tensor& logabsdet) {
auto singular_case_backward = [&](const Tensor& grad_logabsdet, const Tensor& self) -> Tensor {
Tensor u, sigma, v;
std::tie(u, sigma, v) = self.svd();
// sigma has all non-negative entries (with at least one zero entry, since
// det = 0), so logabsdet = \sum log(abs(sigma)) = \sum log(sigma),
// and we backprop through \sum log(sigma).
auto gsigma = grad_logabsdet.unsqueeze(-1).div(sigma);
return svd_backward({{}, gsigma, {}}, self, true, true, u, sigma, v);
};
auto nonsingular_case_backward = [&](const Tensor& grad_logabsdet, const Tensor& self) -> Tensor {
return unsqueeze_multiple(grad_logabsdet, {-1, -2}, self.dim()) * self.inverse().transpose(-2, -1);
};
if (self.dim() == 2) {
if (signdet.item<double>() == 0) {
return singular_case_backward(grad_logabsdet, self);
} else {
return nonsingular_case_backward(grad_logabsdet, self);
}
} else {
auto nonzero_signdet_indices = at::native::toListOfOptionalTensors(at::where(signdet));
c10::optional<Tensor> first_nonzero_signdet_index = nonzero_signdet_indices[0];
if (first_nonzero_signdet_index->size(0) == logabsdet.numel()) { // all determinants are nonzero (non-singular)
return nonsingular_case_backward(grad_logabsdet, self);
}
auto zero_signdet_indices = at::native::toListOfOptionalTensors(at::where(signdet == 0));
c10::optional<Tensor> first_zero_signdet_index = zero_signdet_indices[0];
if (first_zero_signdet_index->size(0) == logabsdet.numel()) { // all determinants are zero (singular)
return singular_case_backward(grad_logabsdet, self);
}
Tensor grad_slogdet = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
// invertible case
grad_slogdet.index_put_(/*indices=*/nonzero_signdet_indices,
/*value=*/nonsingular_case_backward(grad_logabsdet.index(nonzero_signdet_indices),
self.index(nonzero_signdet_indices)));
// non-invertible case, uses SVD
grad_slogdet.index_put_(/*indices=*/zero_signdet_indices,
/*value=*/singular_case_backward(grad_logabsdet.index(zero_signdet_indices),
self.index(zero_signdet_indices)));
return grad_slogdet;
}
}
// Reference:
// https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
// Sec. 2.3.1 Matrix inverse product
std::tuple<Tensor, Tensor> triangular_solve_backward(
const Tensor & grad_x, const Tensor & grad_m,
const Tensor & b, const Tensor & a, const Tensor & x,
const bool upper, const bool transpose, const bool unitriangular,
std::array<bool, 2> output_mask) {
Tensor grad_b, grad_a;
if (grad_x.defined() || grad_m.defined()) {
if (grad_x.defined()) {
grad_b = std::get<0>(grad_x.triangular_solve(a.conj(), upper, !transpose, unitriangular));
if (output_mask[1]) {
grad_a = transpose ? -x.conj().matmul(grad_b.transpose(-1, -2)) : -grad_b.matmul(x.transpose(-1, -2).conj());
if (upper) {
grad_a = grad_a.triu((int) unitriangular);
} else {
grad_a = grad_a.tril(-((int) unitriangular));
}
}
}
if (!grad_a.defined()) {
grad_a = at::zeros({1}, a.options()).expand_as(a);
}
if (!grad_b.defined()) {
grad_b = at::zeros({1}, b.options()).expand_as(b);
}
if (output_mask[1] && grad_m.defined()) {
grad_a = grad_a.add(grad_m);
}
}
return std::tuple<Tensor, Tensor>{grad_b, grad_a};
}
std::tuple<Tensor, Tensor> cholesky_solve_backward(
const Tensor& grad_x, const Tensor& self,
const Tensor& input2, const Tensor& result, const bool upper) {
Tensor grad_self, grad_input2;
if (grad_x.defined()) {
grad_self = grad_x.cholesky_solve(input2, /*upper=*/upper);
Tensor common_term = at::matmul(grad_self, result.conj().transpose(-2, -1));
common_term = common_term + common_term.conj().transpose(-2, -1);
if (upper) {
grad_input2 = -at::matmul(input2, common_term);
} else {
grad_input2 = -at::matmul(common_term, input2);
}
}
return std::tuple<Tensor, Tensor>{grad_self, grad_input2};
}
Tensor fft_c2r_backward(const Tensor& grad, IntArrayRef dim, int64_t normalization) {
// Forward is C2R (onesided)
// Think of onesided C2R irfft as
// 1. fill the other half by conjugate symmetry
// 2. inverse C2C ifft
// 3. discard the complex dimension
// So backward is
// 1. R2C rfft (essentially add dummy complex dimension, and dft)
// 2. accumulate gradient by conjugate symmetry
// since rfft results follow conjugate symmetry, we only need to
// double some entries from onesided rfft results, i.e., the ones with
// their reflected indices also landing out of the onesided range. So
// consider the index of last dim:
// i. idx = 0.
// Reflected to (N - 0) % N = 0. Not doubled.
// ii. 0 < idx < floor(N/2) (last).
// Reflected to N - idx, with N > N - idx > ceil(N/2), i.e.,
// outside the onesided range. Doubled.
// iii. idx = floor(N/2) = N/2 (last) when N even.
// Reflected to (N - N/2) % N = N/2. Not doubled.
// iv. idx = floor(N/2) = (N-1)/2 (last) when N odd.
// Reflected to (N - (N-1)/2) % N = (N+1)/2. Doubled.
// Therefore, needs to double
// idx = 1, 2, ..., N/2 - 1 when N even
// idx = 1, 2, ..., (N-1)/2 when N odd
// that is
// idx = 1, 2, ..., N - (floor(N/2) + 1)
// = 1, 2, ..., N - onesided_length
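// For example (illustrative): with N = 6 the onesided length is 4 and
// indices 1, 2 (= N - 4 of them) are doubled; with N = 5 the onesided
// length is 3 and indices 1, 2 (= N - 3 of them) are doubled.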
auto gI = at::_fft_r2c(grad, dim, normalization, /*onesided=*/true);
auto double_length = grad.size(dim.back()) - gI.size(dim.back());
if (double_length > 0) { // also covers case when signal size is zero
gI.narrow(dim.back(), 1, double_length).mul_(2);
}
return gI;
}
Tensor fft_r2c_backward(const Tensor& grad, IntArrayRef dim, int64_t normalization,
bool onesided, int64_t last_dim_size) {
if (!onesided) {
return at::real(at::_fft_c2c(grad, dim, normalization, /*forward=*/false));
}
// Forward is R2C (onesided)
// Think of onesided R2C rfft as
// 1. view as complex numbers (fill complex dim with zeros)
// 2. C2C fft
// 3. discard half of results
// So backward is
// 1. fill the other half with zeros (with `zero_grad_shape` below)
// (C2C ifft only takes twosided inputs so we need to fill here)
// 2. inverse C2C ifft
// 3. discard the complex dim
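// For example (illustrative): if last_dim_size = 5, the onesided grad has
// size 3 along that dimension, so zero_length = 2 and we pad grad with zeros
// back to size 5 before the inverse C2C transform.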
auto half_sizes = grad.sizes();
at::DimVector new_grad_shape(half_sizes.begin(), half_sizes.end());
const auto last_dim = at::maybe_wrap_dim(dim.back(), half_sizes.size());
new_grad_shape[last_dim] = last_dim_size;
const auto zero_length = last_dim_size - grad.size(dim.back());
auto complex_full_grad = zero_length > 0 ? at::zeros(new_grad_shape, grad.options()) : grad;
if (zero_length > 0) {
complex_full_grad.slice(last_dim, 0, half_sizes[last_dim]).copy_(grad);
}
return at::real(at::_fft_c2c(complex_full_grad, dim, normalization, /*forward=*/false));
}
// Helper for batchnorm_double_backward
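// e.g. (illustrative): an [N, C, H, W] input is reduced to [1, C, 1, 1] with
// keepdim=true, or to [C] with keepdim=false.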
Tensor sum_exclude_dim1(const Tensor& to_sum, bool keepdim=true) {
auto r = to_sum.sum(0, keepdim);
int64_t start_point_exclusive = keepdim ? 1 : 0;
for (int64_t dim = r.dim() - 1; dim > start_point_exclusive; dim--) {
r = r.sum(dim, keepdim);
}
return r;
}
// Helper for batchnorm_double_backward
// similar to expand_as_dim1 below, but doesn't do the expand_as; operates as if
// reductions were done with keepdim=True
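// e.g. (illustrative): src of shape [C] with an [N, C, H, W] target becomes
// [1, C, 1, 1].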
Tensor unsqueeze_dim1(const Tensor& src, const Tensor& target) {
auto src_expanded = src;
while (src_expanded.sizes().size() < target.sizes().size() - 1) {
src_expanded = src_expanded.unsqueeze(1);
}
if (src_expanded.sizes().size() == target.sizes().size() - 1) {
src_expanded = src_expanded.unsqueeze(0);
}
return src_expanded;
}
// Helper for batchnorm_double_backward
// because gamma/ggG/ggB are 1-dimensional and represent dim==1, we can't
// do a straight expansion because it won't follow the broadcasting rules.
Tensor expand_as_dim1(const Tensor& src, const Tensor& target) {
auto src_expanded = src;
while (src_expanded.sizes().size() < target.sizes().size() - 1) {
src_expanded = src_expanded.unsqueeze(1);
}
return src_expanded.expand_as(target);
}
std::tuple<Tensor, Tensor, Tensor> batchnorm_double_backward(
const Tensor & input,
const c10::optional<Tensor> & gamma,
const Tensor & ggI,
const Tensor & ggG,
const Tensor & ggB,
const Tensor & gO,
const c10::optional<Tensor> & running_mean,
const c10::optional<Tensor> & running_var,
bool training,
double eps,
const c10::optional<Tensor> & save_mean,
const c10::optional<Tensor> & save_invstd,
std::array<bool,3> output_mask) {
bool affine = isDefined(gamma);
// TODO: Do we have a ScalarOrTensor type? Would such a thing exist?
Tensor gamma_expanded;
Tensor ggG_expanded, ggB_expanded;
if (affine) {
gamma_expanded = expand_as_dim1(*gamma, input);
if (ggG.defined()) {
ggG_expanded = expand_as_dim1(ggG, input);
}
if (ggB.defined()) {
ggB_expanded = expand_as_dim1(ggB, input);
}
} else {
gamma_expanded = at::ones({}, input.options());
}
// define some terms we will reuse
auto M = input.size(0);
for (auto s : input.sizes().slice(2)) {
M *= s;
}
// for half inputs, save_mean, save_invstd are float (ideally, we would cast
// everything else, but not now)
auto mu = unsqueeze_dim1(training ? toLegacyTensor(save_mean).to(input.scalar_type()) : toLegacyTensor(running_mean), input);
auto input_sub_mu = input - mu;
auto sigma2_eps_neg_1_2 = unsqueeze_dim1(
training ? toLegacyTensor(save_invstd).to(input.scalar_type())
: toLegacyTensor(running_var).add(Scalar(eps)).pow(-0.5),
input);
auto sigma2_eps_neg_1 = sigma2_eps_neg_1_2.pow(2);
auto sigma2_eps_neg_3_2 = sigma2_eps_neg_1_2.pow(3);
// calculate gI
auto input_mu_sigma2_neg_3_2 = input_sub_mu * sigma2_eps_neg_3_2;
auto gOinmu_sum = sum_exclude_dim1(gO * input_sub_mu);
auto gO_sum = sum_exclude_dim1(gO);
Tensor gI;
if (ggI.defined() && training) {
auto ggI_sum = sum_exclude_dim1(ggI);
auto ggIinmu_sum = sum_exclude_dim1(ggI * input_sub_mu);
auto all_sub = ((ggI_sum * gO_sum).div_(M)).sub_(sum_exclude_dim1(gO * ggI)).add_(
(sigma2_eps_neg_1 * gOinmu_sum * ggIinmu_sum).mul_(3. / M));
auto gI_0t = (input_mu_sigma2_neg_3_2 * all_sub).div_(M);
auto gI_1t = (ggIinmu_sum * sigma2_eps_neg_3_2).div_(M) * (gO_sum.div(M) - gO);
auto gI_2t = (gOinmu_sum * sigma2_eps_neg_3_2).div_(M) * (ggI_sum.div(M) - ggI);
gI = gamma_expanded * (gI_0t.add_(gI_1t).add_(gI_2t));
}
// add contribution of gamma term to gI
Tensor gI_G_term;
if (affine && ggG.defined()) {
if (training) {
auto t0 = gO * sigma2_eps_neg_1_2;
auto t1 = (sigma2_eps_neg_1_2 * gO_sum).div_(-M);
auto t2 = (input_mu_sigma2_neg_3_2 * sum_exclude_dim1(gO * input_sub_mu)).div_(-M);
gI_G_term = ggG_expanded * (t0.add_(t1).add_(t2));
gI = gI.defined() ? gI.add_(gI_G_term) : gI_G_term;
} else {
gI_G_term = ggG_expanded * sigma2_eps_neg_1_2 * gO;
gI = gI.defined() ? gI.add_(gI_G_term) : gI_G_term;
}
}
// this is the first backward's grad_input
auto first_back_grad_input = [&](const Tensor& gO, const Tensor& gamma) -> Tensor {
auto h0 = (gamma * sigma2_eps_neg_1_2).div_(M);
auto h1 = (M * gO).sub_(sum_exclude_dim1(gO)).sub_(
input_sub_mu.mul(sigma2_eps_neg_1) * sum_exclude_dim1(gO * input_sub_mu));
return h0 * h1;
};
// calculate gG
Tensor gG;
if (affine && ggI.defined()) {
if (training) {
// gG is just the first backwards with the gamma term removed (then shaped properly)
gG = ggI * first_back_grad_input(gO, at::ones({}, sigma2_eps_neg_1_2.options()));
gG = sum_exclude_dim1(gG, false);
} else {
gG = sum_exclude_dim1(ggI * gO * sigma2_eps_neg_1_2, false);
}
}
// calculate ggO
Tensor ggO;
// contribution of input term
if (ggI.defined()) {
if (training) {
ggO = first_back_grad_input(ggI, gamma_expanded);
} else {
ggO = ggI * sigma2_eps_neg_1_2 * gamma_expanded;
}
}
if (ggG.defined()) {
auto ggO_G_term = ggG_expanded * input_sub_mu * sigma2_eps_neg_1_2;
ggO = ggO.defined() ? ggO.add_(ggO_G_term) : ggO_G_term;
}
if (ggB.defined()) {
auto ggO_B_term = ggB_expanded;
ggO = ggO.defined() ? ggO.add_(ggO_B_term) : ggO_B_term;
}
if (output_mask[1] && !gG.defined()) {
AT_ASSERTM(affine, "gamma should always be defined when it requires grad");
}
return std::tuple<Tensor, Tensor, Tensor>{gI, gG, ggO};
}
std::tuple<Tensor, Tensor, Tensor>
infinitely_differentiable_native_layer_norm_backward(
const Tensor& dY,
const Tensor& dmean,
const Tensor& drstd,
const Tensor& X,
const Tensor& mean,
const Tensor& rstd,
const c10::optional<Tensor>& gamma,
IntArrayRef normalized_shape,
double eps,
std::array<bool, 3> grad_input_mask) {
const int normalized_ndim = normalized_shape.size();
const auto input_shape = X.sizes();
const auto input_ndim = X.dim();
const int axis = input_ndim - normalized_ndim;
const int64_t M =
at::prod_intlist(input_shape.cbegin(), input_shape.cbegin() + axis);
const int64_t N =
at::prod_intlist(input_shape.cbegin() + axis, input_shape.cend());
Tensor dX;
Tensor dgamma;
Tensor dbeta;
const Tensor X_tensor = X.reshape({M, N});
const Tensor mean_tensor = mean.reshape({M, 1});
const Tensor rstd_tensor = rstd.reshape({M, 1});
const double s = 1.0 / static_cast<double>(N);
Tensor dY_tensor;
if (dY.defined()) {
dY_tensor = dY.reshape({M, N});
}
if (grad_input_mask[0]) {
Tensor gamma_tensor;
if (isDefined(gamma)) {
gamma_tensor = gamma->reshape({1, N});
}
Tensor rstd_cube = rstd_tensor * rstd_tensor * rstd_tensor;
Tensor var;
Tensor dvar;
if (drstd.defined()) {
var = ((rstd_tensor * rstd_tensor).reciprocal_() - eps).clamp_min(0);
dvar = -0.5 * rstd_cube * drstd.view({M, 1});
}
Tensor ds;
Tensor db;
if (dY.defined()) {
ds = (isDefined(gamma) ? dY_tensor * X_tensor * gamma_tensor
: dY_tensor * X_tensor)
.sum(1)
.unsqueeze_(-1);
db = (isDefined(gamma) ? dY_tensor * gamma_tensor : dY_tensor)
.sum(1)
.unsqueeze_(-1);
const Tensor& a = rstd_tensor;
const Tensor b = (db * mean_tensor - ds) * rstd_cube * s;
const Tensor c = -b * mean_tensor - db * rstd_tensor * s;
if (isDefined(gamma)) {
dX = a * dY_tensor * gamma_tensor + b * X_tensor + c;
} else {
dX = a * dY_tensor + b * X_tensor + c;
}
if (dmean.defined() && drstd.defined()) {
dX += var_std_mean_backward(
{dvar, dmean.view({M, 1})},
X_tensor,
var,
mean_tensor,
{1},
false,
true,
false);
}
dX = dX.reshape_as(X);
} else if (dmean.defined() && drstd.defined()) {
dX = var_std_mean_backward(
{dvar, dmean.view({M, 1})},
X_tensor,
var,
mean_tensor,
{1},
false,
true,
false)
.reshape_as(X);
}
}
if (grad_input_mask[1] && dY.defined()) {
dgamma = (dY_tensor * (X_tensor - mean_tensor) * rstd_tensor)
.sum(0)
.reshape_as(toLegacyTensor(gamma));
}
if (grad_input_mask[2] && dY.defined()) {
dbeta = dY_tensor.sum(0).reshape_as(toLegacyTensor(gamma));
}
return std::make_tuple(dX, dgamma, dbeta);
}
std::tuple<Tensor, Tensor, Tensor>
infinitely_differentiable_native_group_norm_backward(
const Tensor& dY,
const Tensor& dmean,
const Tensor& drstd,
const Tensor& X,
const Tensor& mean,
const Tensor& rstd,
const c10::optional<Tensor>& gamma,
int64_t N,
int64_t C,
int64_t HxW,
int64_t group,
double eps,
std::array<bool, 3> grad_input_mask) {
const int64_t G = group;
const int64_t D = C / G;
const double s = 1.0 / static_cast<double>(D * HxW);
Tensor dX;
Tensor dgamma;
Tensor dbeta;
const Tensor X_tensor = X.reshape({N, G, D, HxW});
const Tensor mean_tensor = mean.reshape({N, G, 1, 1});
const Tensor rstd_tensor = rstd.reshape({N, G, 1, 1});
Tensor dY_tensor;
Tensor ds;
Tensor db;
if (dY.defined()) {
dY_tensor = dY.reshape({N, G, D, HxW});
ds = (dY_tensor * X_tensor).sum(3).unsqueeze_(-1);
db = dY_tensor.sum(3).unsqueeze_(-1);
}
if (grad_input_mask[0]) {
Tensor gamma_tensor;
if (isDefined(gamma)) {
gamma_tensor = gamma->reshape({1, G, D, 1});
}
const Tensor var =
((rstd_tensor * rstd_tensor).reciprocal_() - eps).clamp_min(0);
const Tensor rstd_cube = rstd_tensor * rstd_tensor * rstd_tensor;
Tensor dvar;
if (drstd.defined()) {
dvar = -0.5 * rstd_cube * drstd.view({N, G, 1, 1});
}
if (dY.defined()) {
const Tensor a =
isDefined(gamma) ? rstd_tensor * gamma_tensor : rstd_tensor;
Tensor b = (isDefined(gamma) ? (ds * gamma_tensor).sum(2) : ds.sum(2))
.unsqueeze_(-2);
Tensor c = (isDefined(gamma) ? (db * gamma_tensor).sum(2) : db.sum(2))
.unsqueeze_(-2);
b = (c * mean_tensor - b) * rstd_cube * s;
c = -b * mean_tensor - c * rstd_tensor * s;
dX = a * dY_tensor + b * X_tensor + c;
if (dmean.defined() && drstd.defined()) {
dX += var_std_mean_backward(
{dvar, dmean.view({N, G, 1, 1})},
X_tensor,
var,
mean_tensor,
{2, 3},
false,
true,
false);
}
dX = dX.reshape_as(X);
} else if (dmean.defined() && drstd.defined()) {
dX = var_std_mean_backward(
{dvar, dmean.view({N, G, 1, 1})},
X_tensor,
var,
mean_tensor,
{2, 3},
false,
true,
false)
.reshape_as(X);
}
}
if (grad_input_mask[1] && dY.defined()) {
dgamma = ((ds - db * mean_tensor) * rstd_tensor).sum(0).reshape_as(toLegacyTensor(gamma));
}
if (grad_input_mask[2] && dY.defined()) {
dbeta = db.sum(0).reshape_as(toLegacyTensor(gamma));
}
return std::make_tuple(dX, dgamma, dbeta);
}
std::tuple<Tensor, Tensor, Tensor> _trilinear_backward(const Tensor& grad_out, const Tensor& i1, const Tensor& i2, const Tensor& i3,
IntArrayRef expand1, IntArrayRef expand2, IntArrayRef expand3,
IntArrayRef sumdim, int64_t unroll_dim, std::array<bool, 3> grad_mask) {
Tensor grad_i1, grad_i2, grad_i3;
if (grad_out.defined()) {
if (grad_mask[0])
grad_i1 = at::_trilinear(grad_out, i2, i3, sumdim, expand2, expand3, expand1);
if (grad_mask[1])
grad_i2 = at::_trilinear(i1, grad_out, i3, expand1, sumdim, expand3, expand2);
if (grad_mask[2])
grad_i3 = at::_trilinear(i1, i2, grad_out, expand1, expand2, sumdim, expand3);
}
return std::tuple<Tensor, Tensor, Tensor>(grad_i1, grad_i2, grad_i3);
}
Tensor log1p_backward(const Tensor& grad, const Tensor& self) {
if (self.is_sparse()) {
AT_ERROR(
"log1p of a sparse tensor is made to be non-differentiable since ",
"local gradient of zero is 1 / (0 + 1) = 1 and it makes the tensor dense. ",
"Use a different mathematical operation which preserves sparsity of gradients, ",
"or report a bug if you think this is an error.");
}
return grad / (self + 1).conj();
}
Tensor sparse_constructor_values_backward(const Tensor& sparse_grad_out, const Tensor& indices, IntArrayRef values_shape) {
// TODO: improve this backward by writing a kernel (maybe)
auto dense_grad = sparse_grad_out.is_sparse() ? sparse_grad_out.to_dense() : sparse_grad_out;
auto full_size = sparse_grad_out.sizes();
auto flattened_grad_shape = values_shape.vec();
flattened_grad_shape[0] = at::prod_intlist(full_size.slice(0, indices.size(0)));
auto flattened_dense_grad = dense_grad.view(flattened_grad_shape);
auto flattened_indices = at::sparse::flatten_indices(indices, full_size);
return flattened_dense_grad.index_select(0, flattened_indices);
}
// Because the backward of pad(input, pads) is just pad(grad_output, [-p for p in pads])
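// e.g. (illustrative): pad = [1, 2] on the last dimension => the backward is
// constant_pad_nd(grad, [-1, -2]), which trims 1 element at the front and 2 at
// the back of that dimension.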
Tensor constant_pad_nd_backward(const Tensor& grad, IntArrayRef pad) {
auto negated_pad = pad.vec();
std::transform(negated_pad.cbegin(), negated_pad.cend(), negated_pad.begin(), std::negate<int64_t>());
return at::constant_pad_nd(grad, negated_pad, 0);
}
Tensor embedding_dense_double_backward(const Tensor & grad, const Tensor & indices, int64_t padding_idx) {
// since first backward takes care of scaling by frequency,
// we don't need to worry about it here.
auto gg_weight = grad.index_select(0, indices.reshape(-1));
// reshape gradient as per the shape of indices
auto size = indices.sizes().vec();
size.push_back(-1);
if (padding_idx >= 0) {
gg_weight.masked_fill_((indices == padding_idx).reshape({-1, 1}), 0);
}
return gg_weight.view(size);
}
Tensor index_backward(Tensor zeros_like_self, const torch::List<c10::optional<Tensor>>& indices, const Tensor& grad) {
return at::_index_put_impl_(zeros_like_self, indices, grad, true, true);
}
Tensor _cudnn_ctc_loss_backward(const Tensor& grad_out, const Tensor& loss, const Tensor& raw_grad, bool zero_infinity) {
if (zero_infinity) {
return at::where(
loss.unsqueeze(0).unsqueeze(2) == 0,
at::zeros({0}, raw_grad.options()),
raw_grad * grad_out.unsqueeze(0).unsqueeze(2));
} else {
return raw_grad * grad_out.unsqueeze(0).unsqueeze(2);
}
}
bool any_variable_defined(variable_list& variables) {
for (const auto& variable : variables) {
if (variable.defined()) {
return true;
}
}
return false;
}
} // namespace details
} // namespace generated
} // namespace autograd
} // namespace torch