| #include <torch/csrc/autograd/FunctionsManual.h> |
| #include <torch/csrc/autograd/variable.h> |
| |
| #include <ATen/ATen.h> |
| #include <ATen/BatchedTensorImpl.h> |
| #include <ATen/core/Reduction.h> |
| #include <ATen/Dispatch.h> |
| #include <ATen/ExpandUtils.h> |
| #include <ATen/native/IndexingUtils.h> |
| #include <ATen/native/LinearAlgebraUtils.h> |
| #include <ATen/ScalarOps.h> |
| #include <ATen/SparseTensorUtils.h> |
| #include <ATen/Utils.h> |
| #include <ATen/WrapDimUtils.h> |
| #include <ATen/WrapDimUtilsMulti.h> |
| #include <c10/core/TensorOptions.h> |
| #include <c10/util/accumulate.h> |
| |
| #include <ciso646> |
| #include <algorithm> |
| #include <numeric> |
| #include <functional> |
| // Helper functions for autogenerated code |
| // These used to be inlined into the codegened Functions.cpp |
| |
| namespace torch { |
| namespace autograd { |
| namespace generated { |
| namespace details { |
| |
| using at::Tensor; |
| using at::Scalar; |
| using at::IntArrayRef; |
| using at::TensorList; |
| |
// User-facing explanation of the CuDNN RNN double-backward limitation.
// The text is split over adjacent string literals purely for readability;
// the compiler concatenates them into a single message.
const char* kCudnnDoubleBackwardMsg =
    "Double backwards is not supported for CuDNN RNNs due to limitations in the CuDNN API. "
    "To run double backwards, please disable the CuDNN backend temporarily while running "
    "the forward pass of your RNN. For example: \n"
    "with torch.backends.cudnn.flags(enabled=False):\n"
    " output = model(inputs)";
| |
| |
| bool isDefined(const c10::optional<Tensor>& t) { |
| return t.has_value() && t->defined(); |
| } |
| |
| bool isFwGradDefined(const c10::optional<Tensor>& t) { |
| return t.has_value() && t->defined() && t->fw_grad(/*level */ 0).defined(); |
| } |
| |
| Tensor toLegacyTensor(const c10::optional<Tensor>& t) { |
| return t.has_value() ? *t : Tensor(); |
| } |
| |
| Tensor toLegacyFwGrad(const c10::optional<Tensor>& t) { |
| return (t.has_value() && t->defined()) ? t->fw_grad(/*level */ 0) : Tensor(); |
| } |
| |
| Tensor toLegacyPrimal(const c10::optional<Tensor>& t) { |
| return (t.has_value() && t->defined()) ? t->_fw_primal(/*level */ 0) : Tensor(); |
| } |
| |
| void copy_range(variable_list& out, IndexRange range, const Tensor & t) { |
| AT_ASSERT(range.second <= out.size()); |
| AT_ASSERTM(range.second - range.first == 1, "inconsistent range for Tensor output"); |
| out[range.first] = t; |
| } |
| |
| void copy_range(variable_list& out, IndexRange range, at::ArrayRef<Tensor> t) { |
| AT_ASSERT(range.second <= out.size()); |
| AT_ASSERTM(range.second - range.first == t.size(), "inconsistent range for TensorList output"); |
| std::copy(t.begin(), t.end(), out.begin() + range.first); |
| } |
| |
| Tensor copysign_tensor_self_backward(const Tensor & grad, const Tensor & self, const Tensor & result) { |
| auto ratio = result / self; |
| ratio.masked_fill_(self == 0, 0); |
| return grad * ratio; |
| } |
| |
| template <typename T> |
| T not_implemented_base(const char* name, const char* reason) { |
| std::string msg = c10::str("the derivative for '", name, "' is not implemented."); |
| if (strlen(reason) > 0) { |
| msg = c10::str(msg, " ", reason); |
| }; |
| throw std::runtime_error(msg); |
| } |
| |
| Tensor not_implemented(const char* name, const char* reason) { |
| return not_implemented_base<Tensor>(name, reason); |
| } |
| |
| std::vector<Tensor> not_implemented_list(const char* name, const char* reason) { |
| return not_implemented_base<std::vector<Tensor>>(name, reason); |
| } |
| |
| Tensor maybe_multiply(const Tensor & t, const Scalar & s) { |
| bool is_one = false; |
| if (s.isFloatingPoint()) { |
| is_one = s.toDouble() == 1; |
| } else if(s.isIntegral(true)) { |
| is_one = s.toLong() == 1; |
| } |
| |
| if (is_one) { |
| return t; |
| } else { |
| return t * s; |
| } |
| } |
| |
| int64_t _safe_size(IntArrayRef sizes, IntArrayRef dim) { |
| int64_t size = 1; |
| if (sizes.size() == 0) { |
| return 1; |
| } |
| for (auto d : dim) { |
| d = at::maybe_wrap_dim(d, sizes.size()); |
| size *= sizes[d]; |
| } |
| return size; |
| } |
| |
| static Tensor wrapped_scalar_tensor(Scalar scalar) { |
| auto tensor = scalar_to_tensor(scalar); |
| tensor.unsafeGetTensorImpl()->set_wrapped_number(true); |
| return tensor; |
| } |
| |
| Tensor handle_r_to_c(ScalarType self_st, Tensor gradient_result) { |
| if (!at::isComplexType(self_st) && gradient_result.is_complex()) { |
| // R -> C |
| return at::real(gradient_result); |
| } |
| return gradient_result; |
| } |
| |
| Tensor handle_r_to_c(Tensor self, Tensor gradient_result) { |
| if (!self.is_complex() && gradient_result.is_complex()) { |
| // R -> C |
| return at::real(gradient_result); |
| } |
| return gradient_result; |
| } |
| |
| Tensor restore_reduced_dims(const Tensor &output, IntArrayRef dims, bool keepdim) { |
| if (keepdim) { |
| return output; |
| } |
| int64_t total_dims = output.dim() + dims.size(); |
| std::vector<int64_t> target_shape(total_dims, 0); |
| for (int64_t i : dims) { |
| if (i < 0) { |
| i = total_dims + i; |
| } |
| target_shape[i] = 1; |
| } |
| int64_t j = 0; |
| for (int64_t i : output.sizes()) { |
| while (target_shape[j] > 0) j++; |
| target_shape[j++] = i; |
| } |
| return output.reshape(target_shape); |
| } |
| |
| Tensor scale_grad_by_count(const Tensor &grad, const Tensor &mask, IntArrayRef dims) { |
| return (grad / mask.sum(dims, true)) * mask; |
| } |
| |
| std::tuple<Tensor, Tensor> _euclidean_dist_backward(const Tensor & grad, const Tensor & x1, const Tensor & x2, const Tensor & res) { |
| if (!grad.defined()) { |
| return std::tuple<Tensor, Tensor>(Tensor(), Tensor()); |
| } |
| // handle case at 0 where we return a subgradient containing 0 |
| Tensor ratio = grad / res; |
| ratio.masked_fill_(res == 0, 0); |
| return std::tuple<Tensor, Tensor>{ |
| x1 * ratio.sum(-1, true) - ratio.matmul(x2), |
| x2 * ratio.sum(-2, false).unsqueeze(-1) - ratio.transpose(-2, -1).matmul(x1)}; |
| } |
| |
| Tensor norm_backward(const Tensor& grad, const Tensor& self, const optional<Scalar> & p_, const Tensor& norm) { |
| return norm_backward(grad, self, p_, norm, {}, true); |
| } |
| |
| Tensor norm_backward(Tensor grad, const Tensor& self, const optional<Scalar> & p_, Tensor norm, IntArrayRef dim, bool keepdim) { |
| size_t ndim = self.sizes().size(); |
| double p = p_.value_or(2.0).toDouble(); |
| Tensor self_scaled; |
| Tensor scale_v; |
| |
| if (!keepdim && self.dim() != 0) { |
| grad = unsqueeze_multiple(grad, dim, ndim); |
| norm = unsqueeze_multiple(norm, dim, ndim); |
| } |
| |
| if (p == 0.0) { |
| return at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); |
| } else if (p == 1.0) { |
| return self.sgn() * grad; |
| } else if (p == 2.0) { |
| self_scaled = self; |
| scale_v = grad / norm; |
| } else if (std::isinf(p)) { |
| Tensor is_eq_max = (self.abs() == norm).logical_or_(self.isnan().logical_and_(norm.isnan())).type_as(self); |
| self_scaled = self.sign() * is_eq_max; |
| Tensor nb_max = is_eq_max.count_nonzero(dim); |
| if (self.dim() != 0) { |
| nb_max = unsqueeze_multiple(nb_max, dim, ndim); |
| } |
| scale_v = grad / nb_max; |
| } else if (p < 2.0) { |
| self_scaled = self.sgn() * self.abs().pow(p - 1); |
| scale_v = grad / norm.pow(p - 1); |
| } else { |
| self_scaled = self * self.abs().pow(p - 2); |
| scale_v = grad / norm.pow(p - 1); |
| } |
| // handle case at 0 where we return a subgradient containing 0 |
| scale_v.masked_fill_(norm == 0, 0); |
| return self_scaled * scale_v; |
| } |
| |
| Tensor pow_backward(Tensor grad, const Tensor & self, const Scalar & exponent) { |
| if (exponent.equal(0.0)) { |
| return at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); |
| } else { |
| auto grad_lambda = [&](auto exp) { return grad * (exp * self.pow(exp - 1)).conj(); }; |
| Tensor out = (exponent.isComplex()) ? grad_lambda(exponent.toComplexDouble()) : grad_lambda(exponent.toDouble()); |
| return handle_r_to_c(self, out); |
| } |
| } |
| |
| Tensor pow_backward_self(Tensor grad, const Tensor & self, const Tensor & exponent) { |
| auto out = at::where(exponent == 0.0, at::zeros({}, grad.options()), grad * (exponent * self.pow(exponent - 1)).conj()); |
| return handle_r_to_c(self, out); |
| } |
| |
| // Caveats: |
| // We define d(a^b)/db at a = 0 and b < 0 to be -inf. This is due to |
| // d(a^b)/db -> -inf for a fixed b as a -> +0 |
| // Currently, tensorflow defines d(a^b)/db = nan for a = 0 and b < 0. |
| // |
| // We define d(a^b)/db = 0 for a = 0 and b = 0 by continuity as |
| // d(a^b)/db = 0 for a > 0 and b -> +0. |
| // Currently, tensorflow agrees with us. |
| Tensor pow_backward_exponent(Tensor grad, const Tensor& self, const Tensor& exponent, Tensor result) { |
| Tensor cond; |
| if (exponent.is_complex()) { |
| auto is_real_exp = at::logical_and(at::imag(exponent) == 0, at::real(exponent) >= 0); |
| cond = at::logical_and(self == 0, is_real_exp); |
| } else { |
| cond = at::logical_and(self == 0, exponent >= 0); |
| } |
| auto out = grad * at::where(cond, |
| at::zeros({}, grad.options()), |
| (result * self.log()).conj()); |
| return handle_r_to_c(exponent, out); |
| } |
| |
| Tensor pow_backward_exponent(Tensor grad, const Scalar & base, const Tensor& exponent, Tensor result) { |
| auto grad_lambda = [](Tensor a, Scalar b) { return (a * b.log()).conj(); }; |
| if (base.equal(0.0)) { |
| auto cond = [](auto exp) { |
| if (exp.is_complex()) { |
| return at::logical_and(at::imag(exp) == 0, at::real(exp) >= 0); |
| } else { |
| return exp >=0; |
| } |
| }; |
| auto out = grad * at::where(cond(exponent), |
| at::zeros({}, grad.options()), |
| grad_lambda(result, base)); |
| return handle_r_to_c(exponent, out); |
| } else { |
| auto out = grad * grad_lambda(result, base); |
| return handle_r_to_c(exponent, out); |
| } |
| } |
| |
| Tensor angle_backward(Tensor grad, const Tensor& self) { |
| if (self.is_complex()) { |
| return at::where(self == 0.0, at::zeros({}, self.options()), |
| grad * self / self.abs().pow(2) * Scalar(c10::complex<double>{0.0, 1.0})); |
| } else { |
| return at::zeros_like(self, at::MemoryFormat::Preserve); |
| } |
| } |
| |
| Tensor mvlgamma_backward(Tensor grad, const Tensor & self, int64_t p) { |
| Tensor args = at::arange(-p / 2. + 0.5, 0.5, 0.5, self.options()); |
| args = args.add(self.unsqueeze(-1)); |
| return grad * args.digamma_().sum(-1); |
| } |
| |
| Tensor sgn_backward(Tensor result, Tensor grad, Tensor self) { |
| if (self.is_complex()) { |
| auto abs = at::abs(self); |
| // C -> C |
| // https://arxiv.org/pdf/1701.00392.pdf Section 4.20 |
| return at::where(abs == 0.0, at::zeros({}, grad.options()), (grad/abs - (at::real(grad/self) * result))); |
| } else { |
| return at::zeros_like(self, at::MemoryFormat::Preserve); |
| } |
| } |
| |
| Tensor mul_tensor_backward(Tensor grad, Tensor other, ScalarType self_st) { |
| auto out = grad * other.conj(); |
| return handle_r_to_c(self_st, out); |
| } |
| |
| Tensor div_tensor_self_backward(Tensor grad, Tensor other, ScalarType self_st, c10::string_view rounding_mode) { |
| if (rounding_mode != "true") { |
| return at::zeros_like(grad, grad.options().dtype(self_st)); |
| } |
| |
| auto result = grad / other.conj(); |
| return handle_r_to_c(self_st, result); |
| } |
| |
| Tensor div_tensor_self_backward(Tensor grad, Tensor other, ScalarType self_st) { |
| return div_tensor_self_backward(grad, other, self_st, "true"); |
| } |
| |
| Tensor div_tensor_other_backward(Tensor grad, Tensor self, Tensor other, c10::string_view rounding_mode) { |
| if (rounding_mode != "true") { |
| return at::zeros_like(grad, grad.options().dtype(other.scalar_type())); |
| } |
| |
| auto result = -grad * ((self / other) / other).conj(); |
| return handle_r_to_c(other, result); |
| } |
| |
| Tensor div_tensor_other_backward(Tensor grad, Tensor self, Tensor other) { |
| return div_tensor_other_backward(grad, self, other, "true"); |
| } |
| |
| Tensor permute_backwards(const Tensor & grad, IntArrayRef fwd_dims) { |
| // invert the permutation |
| auto ndims = fwd_dims.size(); |
| std::vector<int64_t> dims(ndims); |
| for (size_t i = 0; i < ndims; i++) { |
| dims[at::maybe_wrap_dim(fwd_dims[i], ndims)] = i; |
| } |
| return grad.permute(dims); |
| } |
| |
| Tensor rad2deg_backward(const Tensor& grad) { |
| constexpr double M_180_PI = 57.295779513082320876798154814105170332405472466564; |
| return at::mul(grad, wrapped_scalar_tensor(Scalar(M_180_PI))); |
| } |
| |
| Tensor deg2rad_backward(const Tensor& grad) { |
| constexpr double M_PI_180 = 0.017453292519943295769236907684886127134428718885417; |
| return at::mul(grad, wrapped_scalar_tensor(Scalar(M_PI_180))); |
| } |
| |
| Tensor unsqueeze_multiple(const Tensor & t, IntArrayRef dim, size_t n_dims) { |
| auto dims_to_unsqueeze = at::dim_list_to_bitset(dim, n_dims); |
| Tensor res = t; |
| for (size_t i = 0; i < n_dims; i++){ |
| if (dims_to_unsqueeze[i]) { |
| res = res.unsqueeze(i); |
| } |
| } |
| return res; |
| } |
| |
| Tensor sum_backward(const Tensor & grad, IntArrayRef sizes, IntArrayRef dims, bool keepdim) { |
| if (!keepdim && sizes.size() > 0) { |
| if (dims.size()==1) { |
| return grad.unsqueeze(dims[0]).expand(sizes); |
| } else { |
| Tensor res = unsqueeze_multiple(grad, dims, sizes.size()); |
| return res.expand(sizes); |
| } |
| } else { |
| return grad.expand(sizes); |
| } |
| } |
| |
| Tensor nansum_backward(const Tensor & grad, const Tensor & self, IntArrayRef dims, bool keepdim) { |
| auto sizes = self.sizes(); |
| if (!keepdim && sizes.size() > 0) { |
| if (dims.size()==1) { |
| return grad.unsqueeze(dims[0]).expand(sizes) * self.isnan().logical_not(); |
| } else { |
| Tensor res = unsqueeze_multiple(grad, dims, sizes.size()); |
| return res.expand(sizes) * self.isnan().logical_not(); |
| } |
| } else { |
| return grad.expand(sizes) * self.isnan().logical_not(); |
| } |
| } |
| |
| std::vector<int64_t> reverse_list(const IntArrayRef list) { |
| auto result = std::vector<int64_t>(); |
| result.reserve(list.size()); |
| for (auto iter = list.rbegin(); iter != list.rend(); iter++) { |
| result.push_back(*iter); |
| } |
| return result; |
| } |
| |
| Tensor reverse_dim(const Tensor& t, int64_t dim) { |
| Tensor index = at::arange(t.size(dim) - 1, -1, -1, t.options().dtype(at::kLong)); |
| return t.index_select(dim, index); |
| } |
| |
| Tensor prod_safe_zeros_backward(const Tensor &grad, const Tensor& inp, int64_t dim) { |
| if (inp.size(dim) == 1) { |
| return grad; |
| } |
| |
| auto ones_size = inp.sizes().vec(); |
| ones_size[dim] = 1; |
| Tensor ones = at::ones(ones_size, grad.options()); |
| Tensor exclusive_normal_nocp = at::cat({ones, inp.narrow(dim, 0, inp.size(dim) - 1)}, dim); |
| Tensor exclusive_normal = exclusive_normal_nocp.cumprod(dim); |
| |
| Tensor narrow_reverse = reverse_dim(inp.narrow(dim, 1, inp.size(dim) - 1), dim); |
| Tensor exclusive_reverse_nocp = at::cat({ones, narrow_reverse}, dim); |
| Tensor exclusive_reverse = reverse_dim(exclusive_reverse_nocp.cumprod(dim), dim); |
| |
| return grad * (exclusive_normal * exclusive_reverse); |
| } |
| |
| // note that the gradient for prod is equivalent to: |
| // cumprod(exclusive, normal) * cumprod(exclusive, reverse), e.g.: |
| // input: [ a, b, c] |
| // cumprod(exclusive, normal): [1 , a, a * b] |
| // cumprod(exclusive, reverse): [b * c, c, 1] |
| // product: [b * c, a * c, a * b] |
| // and this is safe under input with 0s. |
| Tensor prod_backward(const Tensor& grad, const Tensor& input, const Tensor& result) { |
| if (input.dim() == 0) { |
| return grad; |
| } |
| Tensor zero_idx = (input == 0).nonzero(); |
| if (zero_idx.numel() == 0) { |
| return (grad * result) / input; |
| } else if (zero_idx.size(0) > 1) { |
| return at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); |
| } else { |
| return prod_safe_zeros_backward(grad, input.contiguous().view(-1), 0).view_as(input); |
| } |
| } |
| |
| Tensor prod_backward(Tensor grad, const Tensor& input, Tensor result, int64_t dim, bool keepdim) { |
| if (input.dim() == 0) { |
| return grad; |
| } |
| dim = at::maybe_wrap_dim(dim, input.sizes().size()); |
| if (!keepdim && input.dim() != 1) { |
| grad = grad.unsqueeze(dim); |
| result = result.unsqueeze(dim); |
| } |
| |
| Tensor zero_mask = (input == 0); |
| Tensor slice_zero_count = zero_mask.sum(dim, true); |
| int64_t total_zeros = slice_zero_count.sum().item<int64_t>(); |
| if (total_zeros == 0) { |
| return (grad * result) / input; |
| } else { |
| return prod_safe_zeros_backward(grad, input, dim); |
| } |
| } |
| |
| Tensor solve_backward_self(const Tensor & grad, const Tensor & self, const Tensor & A) { |
| return at::linalg_solve(A.conj().transpose(-2, -1), grad); |
| } |
| |
| Tensor solve_backward_A(const Tensor & grad, const Tensor & self, const Tensor & A, const Tensor & solution) { |
| Tensor grad_self = solve_backward_self(grad, self, A); |
| if (self.ndimension() == 2 && A.ndimension() == 2) { |
| return -at::mm(grad_self, solution.conj().transpose(-2, -1)); |
| } |
| // if self was unsqueezed from (..., M) to (..., M, 1) |
| auto batched_rhs_shape = IntArrayRef(A.sizes().data(), A.dim()-1); // A.shape[:-1] |
| bool is_rhs_broadcasted = self.dim() == 1 || (A.dim()-1 == self.dim() && self.sizes().equals(batched_rhs_shape)); |
| if (is_rhs_broadcasted) { |
| return -at::matmul(grad_self.unsqueeze(-1), solution.unsqueeze(-1).conj().transpose(-2, -1)); |
| } |
| return -at::matmul(grad_self, solution.conj().transpose(-2, -1)); |
| } |
| |
| Tensor cumsum_backward(const Tensor & x, int64_t dim) { |
| // Need to check numel to see if there are no values (such as shape [0,2], and dim to see if x is a scalar. |
| if (x.dim() == 0 || x.numel() == 0) { |
| return x; |
| } |
| auto ret = at::cumsum(-x, dim); |
| auto ret_sum = ret.narrow(dim, ret.size(dim) - 1, 1).clone(at::MemoryFormat::Preserve); |
| ret -= ret_sum.expand(ret.sizes()); |
| ret += x; |
| return ret; |
| } |
| |
| Tensor logsumexp_backward(Tensor grad, const Tensor & self, Tensor result, IntArrayRef dim, bool keepdim) { |
| if (!keepdim && self.dim() != 0) { |
| grad = unsqueeze_multiple(grad, dim, self.sizes().size()); |
| result = unsqueeze_multiple(result, dim, self.sizes().size()); |
| } |
| return grad * (self - result).exp(); |
| } |
| |
| Tensor logcumsumexp_backward(Tensor grad, const Tensor & self, Tensor result, int64_t dim) { |
| if (grad.dim() == 0 || grad.numel() == 0) { |
| return grad; |
| } |
| |
| // Reference: https://github.com/tensorflow/tensorflow/blob/ |
| // 2a5910906a0e0f3dbc186ff9db6386d81a63448c/tensorflow/python/ops/math_grad.py#L1832-L1863 |
| return AT_DISPATCH_FLOATING_TYPES( |
| at::typeMetaToScalarType(grad.dtype()), |
| "logcumsumexp_backward", |
| [grad, self, result, dim]() { |
| auto grad_min = at::empty_like(grad); |
| grad_min.fill_(std::numeric_limits<scalar_t>::lowest()); |
| auto log_grad_positive = at::where(grad > 0, grad.log(), grad_min); |
| auto log_grad_negative = at::where(grad < 0, (-grad).log(), grad_min); |
| |
| auto reverse_logcumsumexp = [dim](auto x) { |
| return at::flip(at::logcumsumexp(at::flip(x, {dim}), dim), {dim}); |
| }; |
| |
| auto output_pos = |
| (reverse_logcumsumexp(log_grad_positive - result) + self).exp(); |
| auto output_neg = |
| (reverse_logcumsumexp(log_grad_negative - result) + self).exp(); |
| |
| return output_pos - output_neg; |
| }); |
| } |
| |
| Tensor unbind_backward(const variable_list& grads, int64_t dim) { |
| IntArrayRef sizes; |
| at::TensorOptions o; |
| for (auto v : grads) { |
| if (v.defined()) { |
| sizes = v.sizes(); |
| o = static_cast<Tensor>(v).options(); |
| break; |
| } |
| } |
| auto grads_tensors = fmap(grads, [&](const Variable& v) { |
| return ( |
| v.defined() ? static_cast<Tensor>(v) : at::zeros({}, o).expand(sizes)); |
| }); |
| return at::stack(grads_tensors, dim); |
| } |
| |
| Tensor unsqueeze_to(const Tensor & self, IntArrayRef sizes) { |
| auto result = self; |
| |
| int64_t nDims = sizes.size(); |
| for (int64_t dim = 0; dim < nDims; dim++) { |
| if (sizes[dim] == 1) { |
| result = result.unsqueeze(dim); |
| } |
| } |
| return result; |
| } |
| |
| Tensor unsqueeze_to(const Tensor & self, int64_t dim, IntArrayRef sizes) { |
| dim = at::maybe_wrap_dim(dim, sizes.size()); |
| // in NumPy it's not an error to unsqueeze a scalar, but we still need to avoided |
| // unsqueezing in the backward. |
| if (sizes.size() > 0 && sizes[dim] == 1) { |
| return self.unsqueeze(dim); |
| } |
| return self; |
| } |
| |
| std::vector<Tensor> cat_tensors_backward(const Tensor & grad, const std::vector<std::vector<int64_t>> &sizes, const std::vector<ScalarType> &dtypes, int64_t dim) { |
| std::vector<Tensor> grad_inputs(sizes.size()); |
| if (!grad.defined()) { |
| return grad_inputs; |
| } |
| dim = at::legacy_cat_wrap_dim(dim, sizes); |
| int64_t accumulate = 0; |
| |
| Tensor grad_; |
| bool grad_is_complex = grad.is_complex(); |
| if (grad_is_complex) { |
| grad_ = at::real(grad); |
| } |
| for (size_t i = 0; i < sizes.size(); ++i) { |
| Tensor grad_val; |
| if (!at::isComplexType(dtypes[i]) && grad_is_complex) { |
| // R -> C |
| grad_val = grad_; |
| } else { |
| grad_val = grad; |
| } |
| auto& shape = sizes[i]; |
| // If input was empty tensor, gradInput should be empty tensor. |
| if (shape == std::vector<int64_t>({0})) { |
| grad_inputs[i] = at::zeros({0}, grad_val.options()); |
| continue; |
| } |
| auto size = shape[dim]; |
| accumulate += size; |
| grad_inputs[i] = grad_val.narrow(dim, accumulate - size, size); |
| } |
| return grad_inputs; |
| } |
| |
| Tensor clamp_backward(const Tensor & grad, const Tensor &self, const optional<Scalar> & min, const optional<Scalar> & max) { |
| // clamp: gradients not defined on min and max, so we return the subgradient 1 for these cases. |
| if (max && min) { |
| return grad * ((self >= *min) * (self <= *max)).type_as(grad); |
| } else if (min) { |
| return grad * (self >= *min).type_as(grad); |
| } else if (max) { |
| return grad * (self <= *max).type_as(grad); |
| } else { |
| return grad; |
| } |
| } |
| |
| // This function is used by load_derivatives.py to replace tensor.strides() |
| // calls that appear in derivative formulas. If the tensor has requires_grad |
| // set, this function returns its strides or throws an error if the tensor |
| // is sparse. If requires_grad is not set, an empty array is returned since |
| // there will be no backward pass. There has one special case, if input is MKLDNN |
| // tensor and has requires_grad set, just return an empty array, the reason is |
| // that MKLDNN tensor is a opaque tensor which has not stride info. |
| // |
| // This function only supports the case where `input` is the tensor whose |
| // single derivative is being calculated. |
| // |
| // This function does not support `self` derivatives for inplace functions. |
| // |
| // Args: |
| // input Tensor to call .strides() on |
| // input_name Name of `input` tensor, from derivative formula |
| at::IntArrayRef strides_or_error(const Tensor & input, c10::string_view const & input_name) { |
| // TODO: Ideally, this function would never be called if requires_grad is |
| // not set. Once codegen is updated to avoid the call, we can remove this |
| // check. |
| if (input.requires_grad()) { |
| TORCH_CHECK( |
| !input.is_sparse(), |
| "The backward pass for this operation requires the '", input_name, |
| "' tensor to be strided, but a sparse tensor was given instead. ", |
| "Please either use a strided tensor or set requires_grad=False for '", |
| input_name, "'"); |
| if (input.is_mkldnn()) return IntArrayRef({}); |
| return input.strides(); |
| } else { |
| return IntArrayRef({}); |
| } |
| } |
| |
| Tensor mm_mat1_backward(const Tensor & grad, const Tensor & mat2, at::IntArrayRef mat1_sizes, at::IntArrayRef mat1_strides, const Scalar & alpha) { |
| // if input was column-major, return grad as column-order for efficiency |
| if (mat1_strides[0] == 1 && mat1_strides[1] == mat1_sizes[0]) { |
| return maybe_multiply(mat2.conj().mm(grad.t()).t(), alpha.conj()); |
| } else { |
| return maybe_multiply(grad.mm(mat2.t().conj()), alpha.conj()); |
| } |
| } |
| |
| Tensor mm_mat2_backward(const Tensor & grad, const Tensor & mat1, IntArrayRef sizes, IntArrayRef strides, const Scalar & alpha) { |
| // if input was column-major, return grad as column-order for efficiency |
| if (strides[0] == 1 && strides[1] == sizes[0]) { |
| if (mat1.is_sparse()) { |
| // Since mm(dense, sparse) doesn't exist, |
| // pass a transposed output matrix to the underlying "addmm" |
| // function directly. |
| int64_t out_rows = mat1.size(1); |
| int64_t out_cols = grad.size(1); |
| Tensor t = at::zeros({}, grad.options()).expand({out_rows, out_cols}, true); |
| Tensor r = at::empty({out_cols, out_rows}, grad.options()).t(); |
| at::addmm_out(r, t, mat1.t(), grad, alpha, 1); |
| return r; |
| } |
| return maybe_multiply(grad.t().mm(mat1.conj()).t(), alpha.conj()); |
| } else { |
| return maybe_multiply(mat1.t().conj().mm(grad), alpha.conj()); |
| } |
| } |
| |
| Tensor _sparse_addmm_sparse_backward(const Tensor& grad, const Tensor& sparse_, const Tensor& dense, const Scalar& alpha) { |
| AT_ASSERT(sparse_.is_sparse()); |
| auto sparse = sparse_.coalesce(); |
| Tensor grad_sparse = maybe_multiply(grad.mm(dense.t()), alpha); |
| return grad_sparse.sparse_mask(sparse); |
| } |
| |
| // This function return a new SparseTensor with values from Tensor `input` filtered by indices of `mask` |
| // and values are ignored. `input` and `mask` are sparse matrices, a sparse tensor with sparse_dim=2 and dense_dim=2, |
| // and they must have the same shape. |
| // Note that the `output` must have the same `indices` as the `mask` so we are using just a clone. |
| // However, to get `values` we have to use specific helper function for CPU/CUDA and use the `mask` data to filter `values` |
| // That's why we created this `_sparse_mask_helper` function. |
| Tensor _sparse_matrix_mask(const Tensor& input, const Tensor& mask){ |
| Tensor output = at::empty_like(mask); |
| Tensor mask_indices = mask._indices().clone(); |
| Tensor r_values; |
| if (mask._nnz() == 0) { |
| r_values = at::zeros_like(mask._values()); |
| } else { |
| r_values = _sparse_mask_helper(input, mask_indices.contiguous()); |
| } |
| at::sparse::get_sparse_impl(output)->set_indices_and_values_unsafe(mask_indices, r_values); |
| return output; |
| } |
| |
| Tensor sparse_sparse_matmul_backward( |
| const Tensor& grad, |
| const Tensor& a, |
| const Tensor& b, |
| int64_t grad_order) { |
| /* |
| To implement the backward algorithm for sparse matrix-matrix matmul (SPMM) we can start from the following definition |
| for dense tensors: |
| |
| c = a @ b |
| then |
| a_grad = c_grad @ b^T |
| b_grad = a^T @ c_grad |
| |
| So for sparse matrices we can use the following definition: |
| |
| if grad_order == 0: |
| a_grad = sparse_matrix_mask(c_grad @ b^T, mask=a) |
| else: |
| b_grad = sparse_matrix_mask(a^T @ c_grad, mask=b) |
| */ |
| TORCH_CHECK( |
| grad_order == 0 || grad_order == 1, |
| ": grad_order not in [0, 1] at sparse_sparse_matmul_backward function"); |
| if (grad_order == 0) { |
| auto a_grad = _sparse_sparse_matmul(grad, b.t()); |
| return _sparse_matrix_mask(a_grad.coalesce(), a.coalesce()); |
| } |
| auto b_grad = _sparse_sparse_matmul(a.t(), grad); |
| return _sparse_matrix_mask(b_grad.coalesce(), b.coalesce()); |
| } |
| |
| Tensor renorm_backward(const Tensor & grad, const Tensor & self, Scalar p, int64_t dim, Scalar maxnorm) { |
| auto transposed_sizes = self.transpose(dim, 0).sizes().vec(); |
| auto flatten = [&](const Tensor & t) { |
| return t.transpose(dim, 0).contiguous().view({t.size(dim), -1}); |
| }; |
| auto unflatten = [&](const Tensor & t) { |
| return t.contiguous().view(transposed_sizes).transpose(dim, 0); |
| }; |
| |
| // renorm computes the norm over all dimensions except `dim`, which is why |
| // we need the flatten and unflatten business. TODO: simplify this when we |
| // add support for norm over multiple dimensions. |
| auto self_flat = flatten(self); |
| auto grad_flat = flatten(grad); |
| auto norm_flat = self_flat.norm(p, 1, true); |
| auto grad_output = (self_flat * grad_flat).sum(1, true); |
| auto nb = norm_backward(grad_output, self_flat, p, norm_flat, 1, true); |
| auto invnorm = (norm_flat + 1e-7).reciprocal(); |
| auto grad_norm = unflatten(maxnorm * invnorm * (grad_flat - invnorm * nb)); |
| auto norm = unflatten(norm_flat.expand_as(self_flat)); |
| |
| // TODO: remove the detach once comparison ops no longer require grad |
| auto mask = Variable(norm < maxnorm).detach(); |
| return at::where(mask, grad, grad_norm); |
| } |
| |
| Tensor repeat_backward(Tensor grad, IntArrayRef repeats, IntArrayRef input_shape) { |
| auto find_iter = std::find(repeats.cbegin(), repeats.cend(), 0); |
| if (find_iter != repeats.cend()) { |
| return at::zeros(input_shape, grad.options()); |
| } |
| const auto input_dims = input_shape.size(); |
| int64_t num_unsqueezed = grad.dim() - input_dims; |
| for (int64_t i = 0; i < num_unsqueezed; ++i) { |
| grad = grad.sum(0, false); |
| } |
| |
| at::DimVector grad_size, sum_dims; |
| for (size_t dim = 0; dim < input_dims; ++dim) { |
| int64_t repeat = repeats[dim + num_unsqueezed]; |
| // Reshape gradient (repeat > 1) |
| // Index: [..., dim , ...] [..., dim , dim+1 , ...] |
| // Shape: From [..., dimsize, ...] to [..., repeat, dimsize/repeat, ...] |
| // The gradient tensor at 'dim' is reshaped to 'repeat' times of input tensor. |
| // Then, sum up gradients over repeated tensors along 'dim', and reduce shape |
| // from 'repeat * dimsize/repeat' to 'dimsize/repeat' ('input_dimsize'). |
| // Example: |
| // Size(3, 2) Size(6, 2) |
| // [[v1_0, v1_1], |
| // [v1_2, v1_3], |
| // [[v0, v1], repeat(2, 1) [v1_4, v1_5], |
| // [v2, v3], -------------> [v2_0, v2_1], |
| // [v4, v5]] [v2_2, v2_3], |
| // [v2_4, v2_5]] |
| // |
| // input grad (3, 2) reshape (2, 3, 2) output grad (6, 2) |
| // [[[g1_0, g1_1], [[g1_0, g1_1], |
| // [g1_2, g1_3], [g1_2, g1_3], |
| // [[g1_0+g2_0, g1_1+g2_1], [g1_4, g1_5]], [g1_4, g1_5], |
| // [g1_0+g2_0, g1_1+g2_1], [g2_0, g2_1], |
| // [g1_0+g2_0, g1_1+g2_1]] [[g2_0, g2_1], [g2_2, g2_3], |
| // [g2_2, g2_3], [g2_4, g2_5]] |
| // [g2_4, g2_5]]] |
| // If gradient tensor is reshaped to [..., dimsize/repeat, repeat, ...] and then |
| // sum over 'dim+1'. The gradient for input is not correctly aligned with input. |
| // Example: |
| // input grad (3, 2) reshape (3, 2, 2) output grad (6, 2) |
| // [[[g1_0, g1_1], |
| // [g1_2, g1_3]], [[g1_0, g1_1], |
| // [g1_2, g1_3], |
| // [[g1_0+g1_2, g1_1+g1_3], [[g1_4, g1_5], [g1_4, g1_5], |
| // [g1_4+g2_0, g1_5+g2_1], [g2_0, g2_1]], [g2_0, g2_1], |
| // [g2_2+g2_4, g2_3+g2_5]] [g2_2, g2_3], |
| // [[g2_2, g2_3], [g2_4, g2_5]] |
| // [g2_4, g2_5]]] |
| if (repeat != 1) { |
| grad_size.push_back(repeat); |
| sum_dims.push_back(grad_size.size() - 1); |
| } |
| // Don't need to reshape gradient into (repeat, input_shape[dim]) (repeat == 1) |
| grad_size.push_back(input_shape[dim]); |
| } |
| // One-time Reshape & Sum |
| // Reshape gradient to grad_size: |
| // 1. If repeat equals to 1, append input size at that dimension, |
| // 2. If repeat is larger than 1, append both repeat and input size at that dimension. |
| // Sum over all "repeat" dimensions from sum_dims: |
| // Example: |
| // Input Size (2, 3, 4, 5) |
| // repeat [4, 1, 9, 3] |
| // output/grad Size (8, 3, 36, 15) |
| // grad_size [4, 2, 3, 9, 4, 3, 5] |
| // sum_dims [0, 3, 5] |
| |
| // When repeat 1 time over all original dimensions, the empty sum_dims will reduce |
| // the whole grad tensor into a scalar rather than keeping original dimensions. |
| if (!sum_dims.empty()) { |
| grad = grad.reshape(grad_size); |
| grad = grad.sum(sum_dims); |
| } |
| return grad; |
| } |
| |
| // p1m == 1 - p |
| Tensor _fused_dropout_backward(Tensor grad, Tensor mask, double p1m) { |
| if (grad.requires_grad()) { |
| // Use autograd-friendly backward if double backward is required |
| return grad * (mask.type_as(grad) * (1. / p1m)); |
| } else { |
| return at::_masked_scale(grad, mask, 1. / p1m); |
| } |
| } |
| |
| Tensor evenly_distribute_backward(Tensor grad, const Tensor & input, const Tensor & value) { |
| if (input.is_cuda()) { |
| auto mask = (input == value).logical_or_(input.isnan().logical_and_(value.isnan())); |
| return mask * (grad / mask.sum()); |
| } else { |
| auto mask = value.isnan().item<bool>() ? input.isnan() : input == value; |
| return grad.new_zeros(input.sizes(), input.options()).masked_fill_(mask, grad / mask.sum()); |
| } |
| } |
| |
| Tensor var_backward(const Tensor & grad, const Tensor & self, bool unbiased) { |
| return (2.0 / (self.numel() - unbiased)) * grad * (self - self.mean()); |
| } |
| |
| Tensor var_backward(Tensor grad, const Tensor & self, IntArrayRef dim, bool unbiased, bool keepdim) { |
| if (self.dim() == 0) { |
| return var_backward(grad, self, unbiased); |
| } |
| if (!keepdim && self.dim() > 1) { |
| grad = unsqueeze_multiple(grad, dim, self.sizes().size()); |
| } |
| return (2.0 / (_safe_size(self.sizes(), dim) - unbiased)) * grad * (self - self.mean(dim, true)); |
| } |
| |
| Tensor std_backward(const Tensor & result, const Tensor & grad, const Tensor & self, bool unbiased) { |
| return var_backward((grad / (result * 2)).masked_fill_(result == 0, 0), self, unbiased); |
| } |
| |
| Tensor std_backward(const Tensor & result, Tensor grad, const Tensor & self, IntArrayRef dim, bool unbiased, bool keepdim) { |
| return var_backward((grad / (result * 2)).masked_fill_(result == 0, 0), self, dim, unbiased, keepdim); |
| } |
| |
| Tensor mean_backward(Tensor grad, const IntArrayRef sizes, IntArrayRef dim, bool keepdim) { |
| return sum_backward(grad, sizes, dim, keepdim) / _safe_size(sizes, dim); |
| } |
| |
| Tensor mean_backward(Tensor grad, const IntArrayRef sizes, int numel) { |
| return grad.expand(sizes) / numel; |
| } |
| |
| Tensor var_std_mean_backward(const variable_list& grads, const Tensor & self, const Tensor & r1, const Tensor & r2, IntArrayRef dim, bool unbiased, bool keepdim, bool is_std) { |
| Tensor grad; |
| if (grads[0].defined()) { |
| grad = is_std ? std_backward(r1, grads[0], self, dim, unbiased, keepdim) : var_backward(grads[0], self, dim, unbiased, keepdim); |
| } |
| if (grads[1].defined()) { |
| Tensor mean_grad = mean_backward(grads[1], self.sizes(), dim, keepdim); |
| grad = grads[0].defined() ? grad + mean_grad : mean_grad; |
| } |
| return grad; |
| } |
| |
| Tensor var_std_mean_backward(const variable_list& grads, const Tensor & self, const Tensor & r1, const Tensor & r2, bool unbiased, bool is_std) { |
| Tensor grad; |
| if (grads[0].defined()) { |
| grad = is_std ? std_backward(r1, grads[0], self, unbiased) : var_backward(grads[0], self, unbiased); |
| } |
| if (grads[1].defined()) { |
| Tensor mean_grad = mean_backward(grads[1], self.sizes(), self.numel()); |
| grad = grads[0].defined() ? grad + mean_grad : mean_grad; |
| } |
| return grad; |
| } |
| |
| Tensor masked_scatter_backward(const Tensor & grad, const Tensor & mask, IntArrayRef sizes) { |
| int64_t numel = 1; |
| for (auto size : sizes) { |
| numel *= size; |
| } |
| auto mask_selected = grad.masked_select(mask); |
| auto diff_nelem = numel - mask_selected.numel(); |
| if (diff_nelem > 0) { |
| // because mask_selected returns a 1-d tensor with size of masked elements that are 1, |
| // we need to fill out the rest with zeros then reshape back to tensor2's size. |
| auto zeros_fillin = at::zeros({diff_nelem}, grad.options()); |
| mask_selected = at::cat({mask_selected, zeros_fillin}, 0); |
| } |
| return mask_selected.view(sizes); |
| } |
| |
| Tensor cholesky_backward(Tensor grad, bool upper, Tensor L) { |
| // cf. Iain Murray (2016); arXiv 1602.07527 |
| // This gradient is symmetric, and not triangular. |
| // Cholesky additionally assumes that the input is symmetric, which is a subspace of |
| // R^{n x n}, and hence the derivative is not well-defined for off-diagonal |
| // elements. We resolve this by taking the gradient of the functionally independent |
| // elements of the matrix (i.e., the lower triangular portion of the input) and then |
| // reflect it on the upper triangular portion, thereby symmetrizing the gradient of |
| // the cholesky operation. The motivation behind this choice is that symmetric gradient |
| // leads to stable gradient updates, and retains symmetry of the updated matrix if it |
| // were updated by a gradient based algorithm. |
| if (upper) { |
| L = L.transpose(-1, -2).conj(); |
| grad = grad.transpose(-1, -2).conj(); |
| } |
| auto L_inverse = std::get<0>(at::triangular_solve(at::eye(L.size(-1), L.options()), L, /*upper=*/false)); |
| auto phi = at::matmul(L.transpose(-1, -2).conj(), grad); |
| phi.tril_().diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).mul_(0.5); |
| |
| auto grad_input = at::matmul(at::matmul(L_inverse.transpose(-1, -2).conj(), phi), L_inverse); |
| return grad_input.add(grad_input.transpose(-1, -2).conj()).mul_(0.5); // Symmetrizing the gradient |
| } |
| |
| Tensor cholesky_inverse_backward(Tensor grad, Tensor L, bool upper, Tensor inverse) { |
| Tensor grad_L; |
| if (grad.defined()) { |
| Tensor common_term = grad + grad.transpose(-2, -1); |
| common_term = at::matmul(inverse, at::matmul(common_term, inverse)); |
| if (upper) { |
| grad_L = -at::matmul(L, common_term); |
| } else { |
| grad_L = -at::matmul(common_term, L); |
| } |
| } else { |
| grad_L = at::zeros({1}, L.options()).expand_as(L); |
| } |
| return grad_L; |
| } |
| |
| Tensor split_with_sizes_backward(const std::vector<torch::autograd::Variable> &grads, |
| IntArrayRef split_sizes, int64_t dim, IntArrayRef sizes, const at::TensorOptions &options) { |
| dim = at::maybe_wrap_dim(dim, sizes.size()); |
| |
| // it's possible some of the grads are not defined (represents tensors of all 0s). |
| // Since at::cat can't handle those, let's define them |
| std::vector<Tensor> grads_all_defined(grads.size()); |
| for (size_t j = 0; j < grads.size(); ++j) { |
| if (grads[j].defined()) { |
| grads_all_defined[j] = grads[j]; |
| } else { |
| auto length = split_sizes[j]; |
| auto grad_size = sizes.vec(); |
| grad_size[dim] = length; |
| grads_all_defined[j] = at::zeros(grad_size, options); |
| } |
| } |
| |
| auto ret = at::cat(grads_all_defined, dim); |
| return ret; |
| } |
| |
| Tensor split_backward(const std::vector<torch::autograd::Variable> &grads, |
| int64_t split_size, int64_t dim, IntArrayRef sizes, const at::TensorOptions &options) { |
| dim = at::maybe_wrap_dim(dim, sizes.size()); |
| int64_t dim_size = sizes[dim]; |
| int64_t num_splits = grads.size(); |
| std::vector<int64_t> split_sizes(num_splits, split_size); |
| split_sizes[num_splits - 1] = split_size - (split_size * num_splits - dim_size); |
| return split_with_sizes_backward(grads, split_sizes, dim, sizes, options); |
| } |
| |
| Tensor max_pool_double_backward(const Tensor & grad, const Tensor & indices, int dim) { |
| AT_ASSERT(indices.dim() >= dim); |
| auto size = indices.sizes().slice(0, indices.dim() - dim).vec(); |
| size.push_back(-1); |
| auto indices_view = indices.view(size); |
| const auto memory_format = indices.suggest_memory_format(); |
| return grad.contiguous(memory_format).view(size).gather(-1, indices_view).view(indices.sizes()); |
| } |
| |
| Tensor glu_double_backward(const Tensor & grad, const Tensor & grad_output, const Tensor & input, int64_t dim) { |
| auto& gO = grad_output; |
| auto input_size = input.size(dim) / 2; |
| auto first_half = input.narrow(dim, 0, input_size); |
| auto second_half = input.narrow(dim, input_size, input_size); |
| auto sig_second_half = second_half.sigmoid(); |
| auto one_sub_sig_second_half = 1 - sig_second_half; |
| auto sig_one_sub_sig = sig_second_half * one_sub_sig_second_half; |
| |
| auto ggI_first_half = grad.narrow(dim, 0, input_size); |
| auto ggI_second_half = grad.narrow(dim, input_size, input_size); |
| auto ggI_second_half_times_first_half = ggI_second_half * first_half; |
| |
| auto gI_first_half = ggI_second_half * gO * sig_one_sub_sig; |
| auto second_order_sh = sig_one_sub_sig * one_sub_sig_second_half - sig_second_half * sig_one_sub_sig; |
| auto gI_second_half = ggI_second_half_times_first_half * gO * second_order_sh + ggI_first_half * gO * sig_one_sub_sig; |
| return at::cat({gI_first_half, gI_second_half}, dim); |
| } |
| |
| Tensor glu_double_backward_grad_output(const Tensor & grad, const Tensor & input, int64_t dim) { |
| if (dim < 0) dim += input.dim(); |
| auto sizes = input.sizes().vec(); |
| sizes[dim] /= 2; |
| auto tmp = grad * glu_backward(at::ones(sizes, input.options()), input, dim); |
| return tmp.narrow(dim, 0, sizes[dim]) + tmp.narrow(dim, sizes[dim], sizes[dim]); |
| } |
| |
| Tensor infinitely_differentiable_silu_backward( |
| const Tensor& grad_output, |
| const Tensor& input) { |
| const Tensor sigmoid = input.sigmoid(); |
| return grad_output * sigmoid * (1.0 + input * (1.0 - sigmoid)); |
| } |
| |
| Tensor infinitely_differentiable_logit_backward( |
| const Tensor& grad, |
| const Tensor& self, |
| c10::optional<double> eps) { |
| if (eps) { |
| const double lo = eps.value(); |
| const double hi = 1.0 - lo; |
| return at::where( |
| at::logical_and(self >= lo, self <= hi), |
| grad / (self * (1.0 - self)), |
| at::zeros({}, self.options())); |
| } else { |
| return at::where( |
| at::logical_and(self >= 0.0, self <= 1.0), |
| grad / (self * (1.0 - self)), |
| at::empty({}, self.options()) |
| .fill_(std::numeric_limits<double>::quiet_NaN())); |
| } |
| } |
| |
| Tensor kl_div_double_backward_grad_output(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction, bool log_target) { |
| auto result = kl_div_backward(grad, input, target, at::Reduction::None, log_target); |
| if (reduction == at::Reduction::Mean) { |
| return result.mean(); |
| } else if (reduction == at::Reduction::Sum) { |
| return result.sum(); |
| } |
| return result; |
| } |
| |
| // Compute derivatives for targets. |
| Tensor kl_div_target_backward(Tensor grad_output, Tensor self, Tensor target, int64_t reduction, bool log_target) { |
| Tensor grad_target; |
| if (!log_target) { |
| grad_target = grad_output.mul(target.log().add_(1).sub_(self)).masked_fill_(target == 0, 0.); |
| } |
| else { |
| grad_target = grad_output.mul(target.add(1).sub_(self).mul_(target.exp())); |
| } |
| |
| if (reduction == at::Reduction::Mean) { |
| grad_target.div_(target.numel()); |
| } |
| |
| return grad_target; |
| } |
| |
| Tensor binary_cross_entropy_with_logits_target_backward(const Tensor& grad_output, const Tensor& self, const Tensor& target, const c10::optional<Tensor>& weight, const c10::optional<Tensor>& pos_weight, int64_t reduction) { |
| Tensor grad_target; |
| if (isDefined(pos_weight)) { |
| grad_target = (1. - self.sigmoid()).log_().sub_(pos_weight->mul(self.sigmoid().log_())).mul_(grad_output); |
| } else { |
| grad_target = self.mul(-grad_output); |
| } |
| |
| if (isDefined(weight)) { |
| grad_target.mul_(*weight); |
| } |
| |
| if (reduction == at::Reduction::Mean) { |
| grad_target.div_(target.numel()); |
| } |
| |
| return grad_target; |
| } |
| |
| Tensor log_sigmoid_double_backward(const Tensor & grad, const Tensor & input) { |
| auto z = input.sigmoid(); |
| return grad * (z - 1) * z; |
| } |
| |
| Tensor softmax_double_backward(const Tensor & grad, const Tensor & grad_output, int dim, const Tensor & output) { |
| auto gO = grad_output; |
| auto ggI = grad; |
| |
| auto ggI_output = ggI * output; |
| auto ggI_out_sum = ggI_output.sum(dim, true); |
| auto ggI_out_sum_output = ggI_out_sum * output; |
| auto gO_out_sum = (gO * output).sum(dim, true); |
| |
| // gI calculation |
| auto gI_t0 = ggI_output * (gO - gO_out_sum); |
| auto gI_t1 = output * ((ggI_output * gO).sum(dim, true).sub_(gO_out_sum * ggI_out_sum)); |
| auto gI_t2 = ggI_out_sum_output * gO; |
| auto gI_t3 = ggI_out_sum_output * gO_out_sum; |
| return gI_t0 - gI_t1 - gI_t2 + gI_t3; |
| } |
| |
| Tensor log_softmax_double_backward(const Tensor & grad, const Tensor & grad_output, int dim, const Tensor & output) { |
| auto z = output.exp(); |
| return z * grad_output.sum(dim, true) * ((grad * z).sum(dim, true) - grad); |
| } |
| |
| // NOTE: [How to write vmap-compatible backward formulas] |
| // |
| // See NOTE: [vmap-incompatible in-place operations] for what it means for an |
| // in-place operation to be incompatible with vmap. |
| // |
| // If an in-place operation used in a backward formula is vmap-incompatible, |
| // then as developers we have the following options: |
| // |
| // - If the in-place operation directly followed the creation of a tensor with |
| // a factory function like at::zeros(...), we should replace the factory with a |
| // corresponding grad.new_zeros(...) call. The grad.new_zeros(...) call |
| // propagates the batch dims to the resulting tensor. |
| // For example: |
| // Before: at::zeros(input.sizes(), grad.options()).copy_(grad) |
| // After: grad.new_zeros(input.sizes()).copy_(grad) |
| // |
// - If the in-place operation followed some sequence of operations, and
//   we want to be able to vmap over the backward formula as-is (this is
//   usually the case for simple (<15 LOC) backward formulas), then use
//   inplaceIsVmapCompatible to guard the operation. For example:
| // c = a * b |
| // Before: c.mul_(grad) |
| // After: c = at::inplaceIsVmapCompatible(c, grad) ? c.mul_(grad) : c * grad |
| // |
// - If we don't want to vmap directly over the backward formula (e.g., if the
//   backward formula is too complicated or has a lot of vmap-incompatible
//   operations), then register the backward formula as an operator and
//   eventually write a batching rule for it.
| |
| Tensor binary_cross_entropy_double_backward(const Tensor & grad_output, const Tensor & grad, const Tensor & input, const Tensor & target, const c10::optional<Tensor>& weight, int64_t reduction) { |
| auto eps = 1e-12; |
| auto inp_pl_eps = input + eps; |
| auto one_m_inp_pl_eps = 1 - input + eps; |
| // gradient wrt input |
| auto gI = (input * input - 2 * input * target + target) / (inp_pl_eps.pow(2) * one_m_inp_pl_eps.pow(2)); |
| if (at::inplaceIsVmapCompatible(gI, grad)) { |
| gI *= (grad * grad_output); |
| } else { |
| gI = gI * (grad * grad_output); |
| } |
| |
| if (isDefined(weight)) { |
| gI *= *weight; |
| } |
| if (reduction == at::Reduction::Mean) { |
| return gI / input.numel(); |
| } else if (reduction == at::Reduction::Sum) { |
| return gI.sum(); |
| } |
| return gI; |
| } |
| |
| Tensor binary_cross_entropy_double_backward_grad_output(const Tensor & grad, const Tensor & input, const Tensor & target, const c10::optional<Tensor>& weight, int64_t reduction) { |
| auto eps = 1e-12; |
| // gradient wrt grad_output |
| auto ggO = (input - target) / ((input + eps) * (1 - input + eps)); |
| if (at::inplaceIsVmapCompatible(ggO, grad)) { |
| ggO *= grad; |
| } else { |
| ggO = ggO * grad; |
| } |
| |
| if (isDefined(weight)) { |
| ggO *= *weight; |
| } |
| if (reduction == at::Reduction::Mean) { |
| return ggO / input.numel(); |
| } else if (reduction == at::Reduction::Sum) { |
| return ggO.sum(); |
| } |
| return ggO; |
| } |
| |
| Tensor l1_loss_double_backward(const Tensor & grad, const Tensor & grad_output, const Tensor & self, const Tensor & other, int64_t reduction) { |
| if (!self.is_complex()) { |
| return at::zeros_like(grad); |
| } else { |
| auto diff = self - other; |
| auto output = grad_output * sgn_backward(diff.sgn(), grad, diff); |
| if (reduction == at::Reduction::Mean) { |
| output /= self.numel(); |
| } |
| return output; |
| } |
| } |
| |
| Tensor l1_loss_double_backward_grad_output(const Tensor & grad, const Tensor & grad_output, const Tensor & input, const Tensor & target, int64_t reduction) { |
| auto output = at::l1_loss_backward(grad.conj(), input, target, at::Reduction::None); |
| if (reduction == at::Reduction::Mean) { |
| return output.mean(); |
| } else if (reduction == at::Reduction::Sum) { |
| return output.sum(); |
| } |
| return handle_r_to_c(grad_output, output); |
| } |
| |
| Tensor smooth_l1_loss_double_backward(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction, double beta) { |
| // special case to protect against a divide-by-zero. |
| if (beta == 0) { |
| return at::zeros(grad.sizes(), grad.options()); |
| } |
| auto d = (input - target).abs(); |
| auto grad_input = grad * (d < beta).type_as(grad) / beta; |
| if (reduction == at::Reduction::Mean) { |
| grad_input /= input.numel(); |
| } |
| return grad_input; |
| } |
| |
| Tensor smooth_l1_loss_double_backward_grad_output(const Tensor & grad, const Tensor & grad_output, const Tensor & input, const Tensor & target, int64_t reduction, double beta) { |
| if (reduction == at::Reduction::None) { |
| return smooth_l1_loss_backward(grad, input, target, reduction, beta); |
| } |
| auto r = smooth_l1_loss_backward(ones_like(grad_output), input, target, reduction, beta); |
| return (r * grad).sum(); |
| } |
| |
| Tensor mse_loss_double_backward(const Tensor & grad, const Tensor & input, int64_t reduction) { |
| auto grad_input = 2 * grad; |
| if (reduction == at::Reduction::Mean) { |
| grad_input /= input.numel(); |
| } |
| return grad_input; |
| } |
| |
| Tensor mse_loss_double_backward_grad_output(const Tensor & grad, const Tensor & grad_output, const Tensor & input, const Tensor & target, int64_t reduction) { |
| if (reduction == at::Reduction::None) { |
| return mse_loss_backward(grad, input, target, reduction); |
| } |
| auto r = mse_loss_backward(ones_like(grad_output), input, target, reduction); |
| return (r * grad).sum(); |
| } |
| |
| Tensor soft_margin_loss_double_backward(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction) { |
| auto z = (input * -target).exp(); |
| auto zplus1 = z + 1; |
| auto grad_input = grad * (target * target) * z / (zplus1 * zplus1); |
| if (reduction == at::Reduction::Mean) { |
| grad_input /= input.numel(); |
| } |
| return grad_input; |
| } |
| |
| Tensor soft_margin_loss_double_backward_grad_output(const Tensor & grad, const Tensor & grad_output, const Tensor & input, const Tensor & target, int64_t reduction) { |
| if (reduction == at::Reduction::None) { |
| return soft_margin_loss_backward(grad, input, target, reduction); |
| } |
| auto r = soft_margin_loss_backward(ones_like(grad_output), input, target, reduction); |
| return (r * grad).sum(); |
| } |
| |
| Tensor softplus_double_backward(const Tensor & grad, const Tensor & input, Scalar beta, Scalar threshold) { |
| auto x = (input * beta); |
| return sigmoid_backward(grad, x.sigmoid()) * (x < threshold).type_as(grad) * beta; |
| } |
| |
| |
| // NOTE [ as_strided Backward and layout-aware/agnostic autograd ] |
| // |
| // `storage_offset` is ignored for simplicity in this note. If you just want the |
| // full algorithm without explanation, scroll down to bottom of this note. |
| // |
| // Implementing the backward of as_strided is tricky because you have to deal |
| // with mappings that map one memory location to multiple indices, i.e., the |
| // output tensor has multiple indices pointing to **overlapping** memory |
// addresses. This can happen in all sorts of weird cases. For example,
| // |
| // x = torch.randn(15) |
| // x.as_strided([3, 3], [1, 0]) # "expand" case |
| // x.as_strided([3, 3], [2, 1]) # "size too large" case |
| // x.as_strided([3, 2], [3, 6]) # res[2, 0] points to 2*3 + 0*6 = 6 |
| // # res[0, 1] points to 0*3 + 1*6 = 6 |
| // |
| // Here is the general strategy we apply in implementing as_strided backward: |
| // 0. ??? (optimization step. we will talk about this later) |
| // 1. Create some underlying flattened tensor as if it is the base tensor |
| // representing the contiguous memory storage for both input and output. |
| // 2. Use the output geometry to scatter (or index_add) the gradients into |
| // this storage tensor. |
| // 3. ??? (fix for input tensor with overlapping memory. we will talk about |
| // this later) |
| // 4. Return the as_strided view of the storage tensor using input geometry. |
| // |
// In step (2), if the output tensor doesn't have overlapping memory, we can
| // safely scatter (`storage.as_strided(output_geometry).copy_(grad)`); |
| // otherwise, we must use `index_add` as gradients at different indices may need |
| // to be summed to a single location. |
| // |
| // For example, in this case: |
| // |
| // x = torch.randn(3) |
| // y = x.as_strided([3, 3], [1, 0]) # "expand" case |
| // # size [ 3, 3] |
| // # stride [ 1, 0] |
// y.backward() # step (1): contiguous storage tensor `s` of size 3, which
| // is large enough to be used as underlying storage |
| // for `x` and `y`. |
| // s = [ 0, 0, 0] |
| // # step (2): since `y` has overlapping memory, index_add grad |
| // into `s` basing on `y`'s geometry, i.e., |
| // s[i * y.stride(0) + j * y.stride(1)] += gy[i, j]. |
| // s = [ 3, 3, 3] |
| // # step (4): as_strided view `s` using `x`'s geometry |
| // s = [ 3, 3, 3] |
| // grad_input = s.as_strided(x.size(), x.stride()) |
| // = s.as_strided([3], [1]) |
| // = [ 3, 3, 3] |
| // |
| // This is exactly what we would get if using `expand`. However, here the input |
| // tensor doesn't have overlapping memory. If it does, we must add an extra step |
| // before (4). Considering this case: |
| // |
| // t = torch.randn(3) |
| // x = t.expand(3, 3) # input with overlapping memory |
| // # size [3, 3] |
| // # stride [0, 1] |
| // y = x.as_strided([1], [1]) # contiguous output |
| // # size [1] |
| // # stride [1] |
| // y.backward() # step (1): contiguous storage tensor `s` of size 3, which |
| // is large enough to be used as underlying storage |
| // for `x` and `y`. |
| // s = [ 0, 0, 0] |
| // # step (2): scatter grad into `s` basing on `y`'s geometry |
| // s = [ 1, 0, 0] |
| // # step (4): as_strided view `s` using `x`'s geometry |
| // s = [ 1, 0, 0] |
| // grad_input = s.as_strided([3, 3], [0, 1]) |
| // = s.as_strided([3, 3], [0, 1]) |
| // = [[ 1, 0, 0], |
| // [ 1, 0, 0], |
| // [ 1, 0, 0]] |
| // Is this result correct? |
| // |
| // `x.as_strided([1], [1])` call is obviously equivalent with |
| // `x[(0,) * x.dim()].view(1)` for any `x`. But autograd through the second |
| // gives gradient `[ [ 1, 0, 0], [ 0, 0, 0], [ 0, 0, 0]]`. For this specific |
| // case, indexing `x` at any index in first column is also equivalent, and |
| // yields a gradient of shape `[3 x 3]` containing eight 0's and one 1. There is |
| // an `x.size(1)`-times difference between these gradients computed from other |
| // PyTorch ops and the gradient we got from as_strided. |
| // |
| // You might conclude that the gradients from as_strided is wrong. However, |
| // let's first see why they are actually reasonable. Consider the pointwise |
| // perturbations by `delta` anywhere in the first column of `x`. It will lead to |
| // a `delta` change in the same memory location, and then `y` will change by |
| // `delta`. So one can say the gradient should be exactly 1 at the first column, |
| // as given by our above procedure. |
| // |
| // In the above computation of numerical gradients, they only match the |
| // analytical results because strides and memory locations are considered in the |
| // forward pass, i.e., this op (including both forward and backward) is |
| // layout-aware. |
| // |
| // However, in PyTorch, most (probably all) other ops (forward and backward) are |
| // layout-agnostic. E.g., |
| // |
| // t = torch.randn(1) |
| // x = t.expand(2) |
| // y = x.sum() |
| // y.backward() |
| // |
| // Layout-agnostic autograd (as it is currently in PyTorch) will give you |
| // |
| // gy = 1 |
| // gx = [ 1, 1] # SumBackward: torch.ones_like(x) |
| // gt = [ 2] # ExpandBackward: gx.sum() |
| // |
| // Note that `gx = [ 1, 1]`. However, if you perturb any value in `x` by `delta` |
| // (the other will also change by `delta`), `y` will change by `2 * delta`. So |
| // the gradients, if strides are taken into consideration, should be 2. |
| // |
| // Layout-aware autograd should give you |
| // |
| // gy = 1 |
| // gx = [ 2, 2] # Because the backward considers the fact that the input `x` |
| // # is already expanded. |
| // gt = [ 2] # Layout-aware backward of expand is just a slicing because |
| // # the previous backward should have already taken care of |
| // # strides and made sure that gradients are the same along the |
| // # expanded dimension. |
| // |
| // As shown above, these two types are not compatible. Therefore, we must either |
| // make as_strided layout-agnostic, or make all other ops layout-aware. |
| // |
| // It is difficult to support layout-aware autograd (at least in the current |
| // codebase structure), because it would mean |
| // 1. storing tensor geometries of every input tensor for backward |
// 2. depending on input geometry, the gradient computed from backward changes
| // 3. ideally enforcing gradient of T to always have same strides as T |
| // (although these two methods only differ when it comes to overlapping memory) |
| // |
| // Therefore, we must formulate `as_strided` in a layout-agnostic way, i.e., |
| // giving the same output regardless of the input layout. We consider |
| // `input.stride()` as a separate independent fixed argument `input_stride`. |
| // Then, `as_strided(input, size, stride)` can be thought of as: |
| // 1. "Scatter" each value of `input` into a "storage" using storage location |
| // computed from the value's index in `input`, `input.size()` and |
| // `input_stride`, but if N values end up in the same location, the value |
| // is average of those N values (they will be the same value anyways). |
| // |
| // Formal description: |
| // Denote the set of all input indices that pointing to the same storage |
| // location `storage[n]` as `S(n)`, i.e., |
| // |
| // S(n) = { index : <index, input_stride> == n, index is valid given input.size() }, |
| // |
| // where `<x, y>` is the dot product between `x` and `y`. |
| // |
| // Then, the process is: |
| // |
| // storage[n] = Avg { S(n) } |
| // |
// Note that all values in `S(n)` are the same (they point to the same
// memory location anyways), so this step doesn't change anything, but
// effectively avoids having the dependency on the layout of `input`.
| // I.e., the result holds fixed regardless of the layout of `input`, as |
| // long as `input_stride` is fixed. |
| // |
// NOTE: for forward pass, we can equivalently simply select any one of
// `S(n)` as `storage[n]`. However, considering this as an average
| // operation makes backward easier (so all values in set |
| // `{ grad_input[i] : i in S(n) }` are the same, and it can use the |
| // same geometry as input). |
| // 2. As usual, return the as_strided view of `storage` using required output |
| // `size` and `stride`. |
| // |
| // To backward through this layout-agnostic version, we simply add the following |
| // step: |
| // .... (scatter gradients into the storage tensor using output geometry) |
| // 3. For all storage location n, `storage[n] /= |S(n)|`. |
| // .... (return as_strided view of the storage tensor using input geometry) |
| // |
| // Finally, we note that these general operations are expensive, so we apply the |
| // following optimizations: |
| // Add step (0): For all output dimension `d` with output stride 0, sum the |
| // gradients along dimension `d` (don't keepdim), and remove |
| // dimension `d` from output size and stride. |
| // (An optimization for "expand" cases so we may avoid step (3)) |
| // Only apply step (3) when input tensor has overlapping memory. |
| // |
| // FULL ALGORITHM: |
| // 0. For all output dimension `d` with output stride 0, sum the gradients |
| // along dimension `d` (don't keepdim), and remove dimension `d` from |
| // output size and stride. |
| // 1. Create some underlying flattened tensor as if it is the base tensor |
| // representing the contiguous memory storage for both input and output. |
| // 2. Use the output geometry to scatter (or index_add) the gradients into |
| // this storage tensor `storage`. |
| // 3. If input tensor has overlapping memory, |
| // For all storage location `i`, `storage[i] /= N(i)`, where `N(i)` is the |
| // number of indices in input geometry pointing to the same storage |
| // location `i` (i.e., `|S(i)|` in equations above). |
| // 4. Return the as_strided view of the storage tensor using input geometry. |
| // |
| // See NOTE [ Detecting Memory Overlap Within A Strided Tensor ] on how to |
// roughly detect overlapping memory.
| |
| |
| // NOTE [ Detecting Memory Overlap Within A Strided Tensor ] |
| // |
| // Checking memory overlap within a strided tensor is the special case of |
| // detecting memory overlap of two strided tensors, where the two tensors start |
// at the same memory address. The latter is HARD (see #8212).
| // |
// But even this special case isn't simple. This note describes a check for an
// even more constrained simple case where we can be certain that there is no
| // overlap. |
| // |
| // The checking algorithm can be described as: |
| // 0. Return [ pass check ] if any dimension has size 0 |
| // 1. Ignore all dimensions that have size 1 |
| // 2. If no remaining dimensions, return [ pass check ] |
| // 3. Sort the remaining dimensions according to the strides decreasingly |
| // 4. Check that for each dimension k, |
| // |
| // stride[k] > \sum_{ i > k } (size[i] - 1) * stride[i] |
| // |
| // That is equivalent to, after reordering the dimensions so strides are |
| // in decreasing order, checking that stride of each dimension is larger |
| // than the maximum memory offset in a slice at that dimension. |
| // |
| // Obviously this check passes for contiguous tensors ( the dimensions will be |
| // already sorted with LHS = stride[0] = \prod size[i] being exactly 1 larger |
| // than RHS ). Similarly, the check passes for tensors contiguous in all but |
| // the last dimension, and LHS = stride[0] = stride[-1] * \prod size[i] being |
| // exactly stride[-1] larger than RHS. (*) |
| // |
| // We will show that these view operations, including all our view operations |
| // *except for* general as_strided and unfold, also preserve this invariant: |
| // |
| // alias: Obviously preserves |
| // |
| // expand: All changed dimensions are removed in step (1) |
| // |
| // view: Consider the input dimensions as grouped into consecutive |
// dimension "blocks", where dimensions are contiguous in each
// one. view only works when the output dimensions can also be
| // grouped into the same consecutive blocks of same ordering. |
| // |
| // NB: this means that the number of elements and stride of the |
| // last dimension in each block is the same in input and |
| // output. (**) |
| // |
| // Notation: |
| // Consider a single such block B, |
| // ... B_prev[-1]], [ B[0], ..., B[i], ..., B[k] = B[-1] ], [ B_next[0], ... |
| // start--^^^^ ^^^^^^^^^^^^--end |
| // Each B[i] denotes a dimension index such that B[i] = B[0] + i. |
| // |
// We first show that if a tensor (i.e., input) satisfies the
| // invariant, after sorting, the dimensions within each block |
| // still remain consecutive. (***) |
| // |
| // After removing dimensions of size 1, the dimensions within a |
| // block is already sorted by strides in descending order. So |
| // sorting all dimensions will not change the relative ordering |
| // among them. |
| // |
| // Assume that some block B is not consecutive after sorting, |
| // i.e., there exists a dimension d between B[0] and B[-1] in |
| // sorted order. |
| // |
// By (*), we know that
//   stride[B[0]]
//     =  \sum_{i > 0} (size[B[i]] - 1) * stride[B[i]] + stride[B[-1]]
//     <  \sum_{i > 0} (size[B[i]] - 1) * stride[B[i]] + stride[d]
//     <= \sum_{i > 0} (size[B[i]] - 1) * stride[B[i]] + (size[d] - 1) * stride[d]
//     <= \sum_{j > B[0]} (size[j] - 1) * stride[j],
// where j ranges over all dimensions sorted after B[0], and where
//   - the first < comes from sorting (d is sorted before B[-1]),
//   - the second <= comes from the fact that dimension d exists
//     after step (1) and thus must have size greater than 1, and
//   - the third <= comes from the fact that each term in the sum
//     is non-negative.
//
// Then we have a contradiction, as this means the invariant is
// violated at B[0]. So the original proposition is true.
| // |
// Now that we established the above claim (***), we consider the
// view operation as first sorting the dimensions (i.e., blocks),
// applying the original view (since it only cares about dimensions
// being consecutive and contiguous within each block), and then
// undoing the sort.
| // |
| // Consider a single block B in the output, |
//     ... ], [ B[0], ..., B[i], ..., B[k] = B[-1] ], [ ...
//       start--^^^^                  ^^^^^^^^^^^^--end
//
// By (*), we know that for all i
//   stride[B[i]] = stride[B[-1]] +
//                  \sum_{j=i+1}^{k} (size[B[j]] - 1) * stride[B[j]]
//
// Then the invariant is obviously satisfied at every dimension
// in this block if it is satisfied at dimension B[-1]. It only
// remains to show that it is satisfied at the last dimension in
// each block.
| // |
| // Since the same blocks are present in both input and output |
| // with the same ordering, we will abuse the notation in the |
| // following statements. |
| // |
// By (*), we know that the following holds for both input and
// output, for any block B:
//   \sum_{i > B[-1]} (size[i] - 1) * stride[i]
//     = \sum_{block B' after B} (\prod_{j in B'} size[j] - 1) * stride[B'[-1]]
//     = \sum_{block B' after B} (numel(B') - 1) * stride[B'[-1]].
// (Within a contiguous block the terms telescope.)
//
// By (**), numel(B') and stride[B'[-1]] are the same in input and
// output for every block B', so the quantity above remains the
// same in input and output. So both
//   \sum_{i > B[-1]} (size[i] - 1) * stride[i]
// and
//   stride[B[-1]]
// are the same in input and output.
| // |
| // These two quantities are exactly the LHS and RHS of the |
| // invariant inequality. Since by assumption the invariant is |
| // satisfied in input at B[-1], it is also satisfied in output at |
| // B[-1]. This concludes the proof. |
| // |
| // squeeze: Special case of view |
| // |
| // unsqueeze: Special case of view |
| // |
// slice: Consider slicing dimension i with step = k >= 1.
//
// Let stride' and size' be the output strides and sizes. We have
//
//   stride'[i] = k * stride[i]
//   size'[i] <= ceil(size[i] / k)
//
// Moreover, the indices selected along dimension i are
// start, start + k, ..., start + (size'[i] - 1) * k, all of which
// lie in [0, size[i] - 1], so
//
//   (size'[i] - 1) * k <= size[i] - 1.                       (****)
//
// If size'[i] = 1, invariant is obviously satisfied as we are
// just removing a dimension (after step (1)).
//
// Assume size'[i] > 1.
//
// By assumption, the invariant is satisfied at every dimension
// in input.
//
// For any dimension j, if stride[j] > stride[i], we have
//   stride'[j] =  stride[j]
//              >  (size[i] - 1) * stride[i]
//              >= (size'[i] - 1) * k * stride[i]              (by (****))
//              =  (size'[i] - 1) * stride'[i]
//              >= stride'[i].                                 (as size'[i] > 1)
//
// If stride[j] < stride[i], we have
//   stride'[j] = stride[j] < stride[i] <= stride'[i].
//
// So the sorting order remains unchanged after slice.
//
// Since
//   (size'[i] - 1) * stride'[i]
//     =  (size'[i] - 1) * k * stride[i]
//     <= (size[i] - 1) * stride[i],                           (by (****))
// the term from this dimension i in the invariant inequality at
// other dimensions can only decrease after slice. So the
// invariant is preserved.
| // |
| // narrow: Special case of slice |
| // |
| // select: narrow + squeeze |
| // |
| // permute: Sorting makes permutation of dimensions irrelevant |
| // |
| // transpose: Sorting makes swapping dimensions irrelevant |
| // |
| // diagonal: Effectively merging two dimensions i and j into a new |
| // dimension k s.t. |
| // stride'[k] = stride[i] + stride[j] |
| // size'[k] <= min(size[i], size[j]), |
| // where stride and size are on the input, and stride' and size' |
| // are on the output. |
| // |
// Assume that size[i] > 1 and size[j] > 1. If either has size 1,
// then this is unsqueeze on that dimension.
//
// WLOG, say stride[i] <= stride[j].
//
// Each dimension d in input with stride[d] > stride[j] has
//   stride'[d] =  stride[d]
//              >  (size[i] - 1) * stride[i] + (size[j] - 1) * stride[j]
//              >= stride[i] + stride[j]
//              =  stride'[k].
// So, considering the sorted dimensions, this is effectively
// removing i, and replacing j with k.
//
// For dimensions d with stride[i] < stride[d] < stride[j], the
// term from dimension i is removed in the invariant inequality.
// For dimensions d with stride[d] > stride[j], we have
//   (size'[k] - 1) * stride'[k]
//     <= (min(size[i], size[j]) - 1) * (stride[i] + stride[j])
//     <= (size[i] - 1) * stride[i] + (size[j] - 1) * stride[j],
// so the term from i and j in the invariant can only decrease.
//
// So this is generally relaxing the constraint, and thus it
// preserves it.
| |
| // This implements steps (2)~(4) of the algorithm in |
| // NOTE [ Detecting Memory Overlap Within A Strided Tensor ] |
| // Helper for as_strided_backward |
| static inline bool _maybe_overlapping_memory(IntArrayRef sizes, IntArrayRef strides) { |
| if (sizes.size() > 0) { |
| std::vector<std::size_t> argsort(sizes.size()); |
| std::iota(argsort.begin(), argsort.end(), 0); |
| std::sort(argsort.begin(), argsort.end(), |
| [&](std::size_t i, std::size_t j){ return strides[i] < strides[j]; }); |
| |
| int64_t max_index_in_slice = 0; |
| for (auto i : argsort) { |
| auto stride_ = strides[i]; |
| if (stride_ <= max_index_in_slice) { |
| return true; |
| } |
| max_index_in_slice += stride_ * (sizes[i] - 1); |
| } |
| } |
| return false; |
| } |
| |
| // Returns the minimum storage size needed to contain a tensor of sizes, strides, and storage_offset |
| // Helper for as_strided_backward |
| static inline int64_t _min_storage_size(IntArrayRef sizes, IntArrayRef strides, int64_t storage_offset) { |
| int64_t storage_size = storage_offset + 1; |
| int64_t dim = sizes.size(); |
| for (int64_t i = 0; i < dim; i++) { |
| auto size_i = sizes[i]; |
| if (size_i == 0) { |
| return storage_offset; |
| } |
| storage_size += (size_i - 1) * strides[i]; |
| } |
| return storage_size; |
| } |
| |
// Backward of as_strided.
// See NOTE [ as_strided Backward and layout-aware/agnostic autograd ] for explanation.
//
// `grad` is laid out with the output geometry (`sizes`, `strides`,
// `storage_offset_`; when `storage_offset_` is absent the input geometry's
// storage offset is used). The gradient is scattered into a zero-initialized
// flat "storage" tensor that is large enough for both geometries, and the
// result is returned as an as_strided view of that storage with the full
// input geometry. If either geometry has a size-0 dimension, zeros shaped
// like the input are returned.
Tensor as_strided_backward(Tensor grad, TensorGeometry input_geometry, IntArrayRef sizes, IntArrayRef strides, optional<int64_t> storage_offset_) {
  // For output geometry,
  //   check for size 0 dimensions,
  //   skip size 1 dimensions,
  //   reduce grad on expanded dims (stride=0, size>1)
  // Step (0)     for the algorithm in NOTE [ as_strided Backward and layout-aware/agnostic autograd ]
  // Step (0)~(1) for the algorithm in NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
  //              on output geometry
  auto storage_offset = storage_offset_.value_or(input_geometry.storage_offset());
  auto odim = grad.dim();
  std::vector<int64_t> out_sizes_, out_strides_;
  out_sizes_.reserve(odim);
  out_strides_.reserve(odim);
  // Walk dimensions from last to first so that squeezing / summing out dim i
  // never shifts the index of a dimension that has not been visited yet.
  for (int64_t i = odim - 1; i >= 0; i--) {
    auto size_i = sizes[i];
    auto stride_i = strides[i];
    if (size_i == 0) {
      return at::zeros(input_geometry.sizes(), grad.options());
    } else if (size_i == 1) {
      grad = grad.squeeze(i);
    } else if (stride_i == 0) {
      // expanded dimension: every position aliases the same memory, so the
      // gradient is reduced over it
      grad = grad.sum(i, false);
    } else {
      // prepend, so the kept dimensions end up in their original order
      out_sizes_.insert(out_sizes_.begin(), size_i);
      out_strides_.insert(out_strides_.begin(), stride_i);
    }
  }
  // Step (2)~(4) for the algorithm in NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
  //              on output geometry
  auto out_maybe_overlap = _maybe_overlapping_memory(out_sizes_, out_strides_);

  // For input geometry,
  //   check for size 0 dimensions,
  //   skip size 1 dimensions,
  // Step (0)~(1) for the algorithm in NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
  //              on input geometry
  // Unlike the output geometry, stride-0 dimensions are kept here;
  // _maybe_overlapping_memory reports them as (possibly) overlapping.
  auto idim = input_geometry.dim();
  IntArrayRef inp_sizes = input_geometry.sizes(), inp_strides = input_geometry.strides();
  std::vector<int64_t> inp_sizes_, inp_strides_;
  inp_sizes_.reserve(idim);
  inp_strides_.reserve(idim);
  for (int64_t i = idim - 1; i >= 0; i--) {
    auto size_i = inp_sizes[i];
    auto stride_i = inp_strides[i];
    if (size_i == 0) {
      return at::zeros(input_geometry.sizes(), grad.options());
    } else if (size_i != 1) {
      inp_sizes_.insert(inp_sizes_.begin(), size_i);
      inp_strides_.insert(inp_strides_.begin(), stride_i);
    }
  }
  // Step (2)~(4) for the algorithm in NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
  //              on input geometry
  auto inp_maybe_overlap = _maybe_overlapping_memory(inp_sizes_, inp_strides_);


  // Rest of this function implements
  // Step (1)~(4) for the algorithm in NOTE [ as_strided Backward and layout-aware/agnostic autograd ]
  // TODO: Raise if not all output values are visible in input geometry.
  //       Technically speaking, if you treat those values as constants, not
  //       raising is fine, and mathematically correct. However, these values
  //       really are contained in some base tensor, and by treating them as
  //       constants we are ignoring this tight dependency. Therefore, it is
  //       more sensible to raise here.

  // Step (1): create underlying tensor as "storage"
  // Both offsets are rebased so that the smaller one becomes 0, and the
  // storage is sized to cover both the input and the output geometry.
  auto shared_offset = std::min(input_geometry.storage_offset(), storage_offset);
  auto inp_effective_offset = input_geometry.storage_offset() - shared_offset;
  auto out_effective_offset = storage_offset - shared_offset;
  auto base_size = std::max(
    _min_storage_size(inp_sizes_, inp_strides_, inp_effective_offset),
    _min_storage_size(out_sizes_, out_strides_, out_effective_offset)
  );
  auto storage = grad.new_zeros({base_size});

  // prepare indices tensor if we will do index_add_ later
  // (flatten_full_indices[p] == p, so viewing it through a geometry yields the
  //  storage position addressed by each element of that geometry)
  c10::optional<at::Tensor> flatten_full_indices;
  if (inp_maybe_overlap || out_maybe_overlap) {
    flatten_full_indices = at::arange(0, base_size, grad.options().dtype(at::kLong));
  }

  // Step (2): use output geometry to scatter gradients into storage
  if (out_maybe_overlap) {
    // possibly overlapping output: index_add_ sums every contribution that
    // lands on the same storage position
    auto out_indices = flatten_full_indices->as_strided(out_sizes_, out_strides_, out_effective_offset);
    storage.index_add_(0, out_indices.reshape(-1), grad.reshape(-1));
  } else {
    // assume that new tensors have 0 storage offset
    storage.as_strided(out_sizes_, out_strides_, out_effective_offset).copy_(grad);
  }

  // Step (3): if input tensor has overlapping memory, divide scattered gradient
  //           at storage[i] by the number of times i shows up in input geometry
  if (inp_maybe_overlap) {
    auto count = at::zeros_like(storage, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    auto inp_indices = flatten_full_indices->as_strided(inp_sizes_, inp_strides_, inp_effective_offset).reshape(-1);
    count.index_add_(0, inp_indices, at::ones({1}, grad.options()).expand_as(inp_indices));
    storage.div_(count); // this will give nan outside visible range
  }
  // Step (4): return as_strided view of the storage tensor with input geometry
  // (the full input sizes/strides, including size-1 dims, so the result has
  //  exactly the input's shape)
  return storage.as_strided(inp_sizes, inp_strides, inp_effective_offset);
}
| |
| std::tuple<Tensor, Tensor> atan2_backward(const Tensor& grad, const Tensor& self, const Tensor& other, std::array<bool, 2> output_mask) { |
| if (!grad.defined()) { |
| return std::tuple<Tensor, Tensor>{Tensor(), Tensor()}; |
| } |
| auto recip = (self * self + other * other).reciprocal(); |
| return std::tuple<Tensor,Tensor>{ |
| output_mask[0] ? grad * other * recip : Tensor(), |
| output_mask[1] ? grad * -self * recip : Tensor() }; |
| } |
| |
// TODO: Seriously consider writing the derivative formulas for
// each output separately; there is not all that much sharing
// of computation going on here.
//
// Double backward of PReLU. grad_grad_input (ggI) and grad_grad_weight (ggW)
// are the incoming gradients of the first-order grad_input / grad_weight
// outputs. Returns, in order, the gradients w.r.t. grad_out, input and
// weight. Undefined incoming tensors are replaced by zeros; if
// grad_grad_input, grad_grad_weight and grad_out are all undefined, three
// undefined tensors are returned.
std::tuple<Tensor, Tensor, Tensor> prelu_double_backward(
    const Tensor & grad_grad_input,
    const Tensor & grad_grad_weight,
    const Tensor & grad_out,
    const Tensor & input_,
    const Tensor & weight_) {

  if (!(grad_grad_input.defined() || grad_grad_weight.defined() || grad_out.defined())) {
    return std::tuple<Tensor, Tensor, Tensor>(Tensor(), Tensor(), Tensor());
  }
  auto input = input_.contiguous();
  auto weight = weight_.contiguous();

  // Zero-fill undefined grads (TODO: do this more efficiently)
  auto ggI = grad_grad_input.defined() ? grad_grad_input.contiguous() : at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto ggW = grad_grad_weight.defined() ? grad_grad_weight.contiguous() : at::zeros_like(weight, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto gO = grad_out.defined() ? grad_out.contiguous() : at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);

  // 1/0 masks for i > 0 and i <= 0, converted to the gradients' dtype
  auto positive_mask = (input > 0).type_as(ggI);
  auto nonpositive_mask = (input <= 0).type_as(ggW);

  // Explanation: Let input be i, weight be w, grad_output be gO.
  //   f(i, w) = i      if i > 0
  //           = w * i  if i <= 0
  //   gI = df/di * gO = gO      if i > 0       gW = df/dw * gO = 0       if i > 0
  //                   = gO * w  if i <= 0                      = gO * i  if i <= 0
  // The rest is taking derivatives of these wrt i, w, gO and summing/expanding properly.
  // (Below, the returned "gI" / "gW" are the double-backward gradients w.r.t.
  //  input / weight, not the first-order quantities named above.)

  if (weight.numel() == 1) {
    // from PReLU.forward: num_parameters == 0 is used to indicate that a
    // single weight is shared among all input channels.

    // this is a little tricky because PReLU currently doesn't take a shape so the weight may be
    // 1-d when the input is a scalar (and there isn't a good Parameter API for that anyway until Variable
    // and tensor are merged). So, use weight and ggW as 0-dim in this case.
    bool scalar_input_1d_weight = (positive_mask.dim() == 0 && weight.dim() == 1);
    auto weight_maybe_squeeze = scalar_input_1d_weight ? weight.squeeze() : weight;
    auto ggW_maybe_squeeze = scalar_input_1d_weight ? ggW.squeeze() : ggW;

    // d(gI)/d(gO): 1 where i > 0, w where i <= 0
    auto mask = positive_mask + nonpositive_mask * weight_maybe_squeeze.expand_as(input);
    // grad w.r.t. gO = ggI * d(gI)/d(gO) + ggW * d(gW)/d(gO), where d(gW)/d(gO) = i for i <= 0
    auto ggO = ggI * mask + ggW_maybe_squeeze.expand_as(gO) * (nonpositive_mask * input);
    return std::tuple<Tensor, Tensor, Tensor>(
        ggO,
        // grad w.r.t. input: ggW * d(gW)/di = ggW * gO where i <= 0
        ggW_maybe_squeeze.expand_as(gO) * gO * nonpositive_mask,
        // grad w.r.t. weight: ggI * d(gI)/dw = ggI * gO where i <= 0, summed
        // because the single weight is shared by every element
        (ggI * gO * nonpositive_mask).sum().expand_as(weight)
    );
  } else {
    // Expand ggW to match size of ggI; a simple expand doesn't work because
    // ggW is the size of the input channel (dim==1 unless there is only 1 dimension).  For example,
    // let ggI be size (3,4,5,6,7) and ggW be size (4).  Then we unsqueeze ggW to be size (4,1,1,1)
    // so the expand succeeds.
    auto dims_to_unsqueeze = std::max<int64_t>(input.dim() - 2, 0);
    auto ggW_expanded = ggW;
    for (int64_t i = 0; i < dims_to_unsqueeze; i++) {
      ggW_expanded = ggW_expanded.unsqueeze(1);
    }
    ggW_expanded = ggW_expanded.expand_as(ggI);

    // grad w.r.t. input: ggW * gO where i <= 0
    auto gI = ggW_expanded * gO * nonpositive_mask;

    // grad w.r.t. weight: ggI * gO where i <= 0, reduced over every
    // dimension except the channel dimension
    auto gW = ggI * gO * nonpositive_mask;
    if (input.dim() > 1) {
      gW = gW.sum(0);
    }
    while (gW.dim() > 1) {
      gW = gW.sum(1);
    }

    // grad w.r.t. grad_out; left undefined unless gO requires grad
    Tensor ggO;
    if (gO.requires_grad()) {
      // expand weight as input as in ggW/ggI above
      auto weight_expanded = weight;
      for (int64_t i = 0; i < dims_to_unsqueeze; i++) {
        weight_expanded = weight_expanded.unsqueeze(1);
      }
      weight_expanded = weight_expanded.expand_as(input);

      auto mask = positive_mask + nonpositive_mask * weight_expanded;
      ggO = ggI * mask + ggW_expanded * nonpositive_mask * input;
    }
    return std::tuple<Tensor,Tensor,Tensor>{ggO, gI, gW};
  }
}
| |
| Tensor elu_double_backward( |
| const Tensor& grad, |
| const Tensor& grad_output, |
| Scalar alpha, |
| Scalar scale, |
| Scalar input_scale, |
| bool is_result, |
| const Tensor& self_or_result) { |
| |
| if (is_result) { |
| return grad * grad_output * input_scale * (self_or_result < 0).type_as(grad); |
| } else { |
| return at::elu_backward(grad * grad_output * input_scale, alpha, scale, input_scale, is_result, self_or_result) * (self_or_result < 0).type_as(grad); |
| } |
| } |
| |
| Tensor slice_backward_wrapper( |
| const at::Tensor& grad, |
| const c10::IntArrayRef& input_sizes, |
| int64_t dim, |
| c10::optional<int64_t> start, |
| c10::optional<int64_t> end, |
| int64_t step) { |
| auto start_val = start.has_value() ? start.value() : 0; |
| auto end_val = end.has_value() ? end.value() : INT64_MAX; |
| |
| return slice_backward(grad, input_sizes, dim, start_val, end_val, step); |
| } |
| |
// https://j-towns.github.io/papers/svd-derivative.pdf
//
// This makes no assumption on the signs of sigma.
//
// Backward of svd. `grads` holds the incoming gradients for (U, S, V), in that
// order; any of them may be undefined. `raw_u` / `raw_v` are the factors from
// the forward pass and `sigma` the singular values. When `some` is false, only
// the first k = sigma.size(-1) columns of U and V (and of their gradients) are
// used. compute_uv must be true.
Tensor svd_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
                    bool some, bool compute_uv, const Tensor& raw_u, const Tensor& sigma, const Tensor& raw_v) {
  TORCH_CHECK(compute_uv,
           "svd_backward: Setting compute_uv to false in torch.svd doesn't compute singular matrices, ",
           "and hence we cannot compute backward. Please use torch.svd(compute_uv=True)");

  auto m = self.size(-2);
  auto n = self.size(-1);
  auto k = sigma.size(-1);
  auto gsigma = grads[1];

  auto u = raw_u;
  auto v = raw_v;
  auto gu = grads[0];
  auto gv = grads[2];

  if (!some) {
    // We ignore the free subspace here because possible base vectors cancel
    // each other, e.g., both -v and +v are valid base for a dimension.
    // Don't assume behavior of any particular implementation of svd.
    u = raw_u.narrow(-1, 0, k);
    v = raw_v.narrow(-1, 0, k);
    if (gu.defined()) {
      gu = gu.narrow(-1, 0, k);
    }
    if (gv.defined()) {
      gv = gv.narrow(-1, 0, k);
    }
  }
  auto vh = v.conj().transpose(-2, -1);

  // Contribution of gsigma: U diag(gsigma) V^H (zeros if gsigma is undefined).
  Tensor sigma_term;
  if (gsigma.defined()) {
    gsigma = gsigma.to(self.dtype());
    // computes u @ diag(gsigma) @ vh
    sigma_term = at::matmul(u * gsigma.unsqueeze(-2), vh);
  } else {
    sigma_term = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  }
  // in case that there are no gu and gv, we can avoid the series of kernel
  // calls below
  if (!gv.defined() && !gu.defined()) {
    return sigma_term;
  }

  auto uh = u.conj().transpose(-2, -1);
  auto sigma_inv = sigma.pow(-1);
  auto sigma_sq = sigma.pow(2);
  // F[..., i, j] = sigma_j^2 - sigma_i^2
  auto F = sigma_sq.unsqueeze(-2) - sigma_sq.unsqueeze(-1);
  // The following two lines invert the values of F and fill the diagonal with 0s.
  // Notice that F currently has 0s on diagonal. So we fill diagonal with +inf
  // first to prevent nan from appearing in backward of this function.
  F.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(INFINITY);
  F = F.pow(-1);

  Tensor u_term, v_term;

  // Contribution of gu:
  //   U (F o (U^H gu - gu^H U)) diag(sigma) V^H
  //   + (I - U U^H) gu diag(1 / sigma) V^H     (only when m > k)
  if (gu.defined()) {
    auto guh = gu.conj().transpose(-2, -1);
    u_term = at::matmul(u, F.mul(at::matmul(uh, gu) - at::matmul(guh, u)) * sigma.unsqueeze(-2));
    if (m > k) {
      // projection operator onto subspace orthogonal to span(U) defined as I - UU^H
      auto proj_on_ortho_u = -at::matmul(u, uh);
      proj_on_ortho_u.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).add_(1);
      u_term = u_term + proj_on_ortho_u.matmul(gu * sigma_inv.unsqueeze(-2));
    }
    u_term = at::matmul(u_term, vh);
  } else {
    u_term = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  }

  // Contribution of gv:
  //   U diag(sigma) (F o (V^H gv - gv^H V)) V^H
  //   + U diag(1 / sigma) gv^H (I - V V^H)     (only when n > k)
  if (gv.defined()) {
    auto gvh = gv.conj().transpose(-2, -1);
    v_term = sigma.unsqueeze(-1) * at::matmul(F.mul(at::matmul(vh, gv) - at::matmul(gvh, v)), vh);
    if (n > k) {
      // projection operator onto subspace orthogonal to span(V) defined as I - VV^H
      auto proj_on_v_ortho = -at::matmul(v, vh);
      proj_on_v_ortho.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).add_(1);
      v_term = v_term + sigma_inv.unsqueeze(-1) * at::matmul(gvh, proj_on_v_ortho);
    }
    v_term = at::matmul(u, v_term);
  } else {
    v_term = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  }

  // for complex-valued input there is an additional term
  // https://giggleliu.github.io/2019/04/02/einsumbp.html
  // https://arxiv.org/abs/1909.02659
  // It is U diag(L) V^H, where L keeps only the imaginary part of
  // diag(U^H gu), divided by sigma.
  if (self.is_complex() && gu.defined()) {
    Tensor L = at::matmul(uh, gu).diagonal(0, -2, -1);
    at::real(L).zero_();
    at::imag(L).mul_(sigma_inv);
    Tensor imag_term = at::matmul(u * L.unsqueeze(-2), vh);
    return u_term + sigma_term + v_term + imag_term;
  }

  return u_term + sigma_term + v_term;
}
| |
| // "An extended collection of matrix derivative results for forward and reverse mode algorithmic differentiation" |
| // https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf |
| Tensor eig_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self, |
| bool eigenvectors, const Tensor& lambda, const Tensor& v) { |
| // This gradient only works for real eigenvalues at the moment. |
| TORCH_CHECK(eigenvectors, |
| "eig_backward: Setting eigenvectors to false in torch.eig doesn't compute eigenvectors ", |
| "and hence we cannot compute backward. Please use torch.eig(eigenvectors=True)"); |
| auto zeros = at::zeros({1}, lambda.options()); |
| TORCH_CHECK( |
| at::allclose(lambda.slice(/*dim=*/-1, /*start=*/1, /*end=*/2), zeros), |
| "eig_backward: Backward calculation does not support complex eigenvalues at the moment."); |
| |
| auto glambda = grads[0]; |
| auto gv = grads[1]; |
| auto vt = v.transpose(-2, -1); |
| |
| Tensor result; |
| // contribution from the eigenvectors |
| if (gv.defined()) { |
| auto rlambda = lambda.slice(/*dim=*/-1, /*start=*/0, /*end=*/1); |
| |
| auto hm = rlambda.transpose(-2,-1) - rlambda; |
| hm.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(INFINITY); |
| hm.pow_(-1.0); |
| |
| auto gvortho = gv - at::sum(gv * v, /*dim=*/-2, /*keepdim=*/true) * v; |
| auto B = hm * at::matmul(vt, gvortho); |
| auto A = at::matmul(B, vt); |
| |
| std::tie(result, std::ignore) = at::solve(A, vt); |
| } |
| // contribution from eigenvalues |
| if (glambda.defined()) { |
| auto grlambda = glambda.slice(/*dim=*/-1, /*start=*/0, /*end=*/1) * vt; |
| auto A = at::matmul(v, grlambda); |
| auto vvt = at::matmul(v, vt); |
| if (result.defined()) { |
| Tensor result1; |
| std::tie(result1, std::ignore) = at::solve(A, vvt); |
| result = result.add(result1); |
| } |
| else { |
| std::tie(result, std::ignore) = at::solve(A, vvt); |
| } |
| } |
| return result; |
| } |
| |
| // http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf |
| Tensor symeig_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self, |
| bool eigenvectors, bool upper, const Tensor& lambda, const Tensor& v) { |
| // This gradient is symmetric, and not triangular. |
| // symeig operates only on symmetric inputs, which is a subspace of |
| // R^{n x n}, and hence the derivative is not well-defined for off-diagonal |
| // elements. We resolve this by taking the gradient of the functionally independent |
| // elements of the matrix (i.e., the lower triangular portion of the input) and then |
| // reflect it on the upper triangular portion, thereby symmetrizing the gradient of |
| // the symeig operation. The motivation behind this choice is that symmetric gradient |
| // leads to stable gradient updates, and retains symmetry of the updated matrix if it |
| // were updated by a gradient based algorithm. |
| TORCH_CHECK(eigenvectors, |
| "symeig_backward: Setting eigenvectors to false in torch.symeig doesn't compute eigenvectors ", |
| "and hence we cannot compute backward. Please use torch.symeig(eigenvectors=True)"); |
| |
| auto glambda = grads[0]; |
| auto gv = grads[1]; |
| |
| auto vh = v.conj().transpose(-2, -1); |
| |
| Tensor result; |
| if (gv.defined()) { |
| Tensor F = lambda.unsqueeze(-2) - lambda.unsqueeze(-1); |
| F.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(INFINITY); |
| F.pow_(-1); |
| result = at::matmul(v, at::matmul(F * at::matmul(vh, gv), vh)); |
| } else { |
| result = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); |
| } |
| |
| if (glambda.defined()) { |
| glambda = glambda.to(self.dtype()); |
| // computes v @ diag(glambda) @ vh |
| Tensor glambda_term = at::matmul(v * glambda.unsqueeze(-2), vh); |
| if (at::inplaceIsVmapCompatible(result, glambda_term)) { |
| result.add_(glambda_term); |
| } else { |
| result = result + glambda_term; |
| } |
| } |
| return result.add(result.conj().transpose(-2, -1)).mul_(0.5); |
| } |
| |
// Backward of linalg_qr. `grads` holds the incoming gradients for (Q, R);
// either may be undefined. Supported for mode='reduced', and for
// mode='complete' only when nrows <= ncols; mode='r' is rejected.
Tensor linalg_qr_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
                          std::string mode, const Tensor& q, const Tensor& r){
  bool compute_q, reduced;
  std::tie(compute_q, reduced) = at::native::_parse_qr_mode(mode);
  TORCH_CHECK(compute_q, "The derivative of qr is not implemented when mode='r'. "
              "Please use torch.linalg.qr(..., mode='reduced')");

  // Gradient for square / tall inputs (nrows >= ncols). The `A` argument is
  // not read by this lambda.
  auto square_deep_case_backward = [](const Tensor& grad_Q,
                                      const Tensor& grad_R,
                                      const Tensor& A,
                                      const Tensor& Q,
                                      const Tensor& R) -> Tensor {
    // For square and deep (tall) case we refer:
    // Matthias Seeger, Asmus Hetzel, Zhenwen Dai, Eric Meissner, Neil D. Lawrence (2018). Auto-Differentiating Linear Algebra.
    // https://arxiv.org/abs/1710.08717 Section 4.3 LQ Decomposition (Note that LQ decomposition is the transpose of QR decomposition)
    // Hai-Jun Liao, Jin-Guo Liu, Lei Wang, Tao Xiang (2019). Differentiable Programming Tensor Networks.
    // https://arxiv.org/abs/1903.09650 Section 3. QR factorization
    // For derivations of complex-valued input case, see https://giggleliu.github.io/2019/04/02/einsumbp.html

    // Compute R grad_R^H
    Tensor R_term;
    if (grad_R.defined()) {
      R_term = at::matmul(R, grad_R.conj().transpose(-2, -1));
    } else {
      // R is ... x N x N, grad_R is ... x N x N and grad_R.T is ... x N x N
      R_term = at::zeros_like(R, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    }

    // Compute grad_Q^H Q
    Tensor Q_term;
    if (grad_Q.defined()) {
      Q_term = at::matmul(grad_Q.conj().transpose(-2, -1), Q);
    } else {
      // Q is ... x M x N, Q.T is ... x N x M and grad_Q is ... x M x N
      Q_term = at::zeros_like(R, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    }

    Tensor M = R_term - Q_term;

    // Rebuild M from its lower triangle:
    //   M = tril(M) + tril(M)^H, then halve the diagonal.
    // The strictly-lower part is kept and mirrored (conjugate-transposed) onto
    // the upper part; the diagonal becomes (M_ii + conj(M_ii)) / 2 = Re(M_ii).
    Tensor M_tril = at::tril(M);
    M = M_tril + M_tril.conj().transpose(-2, -1);
    M.diagonal(0, -2, -1).mul_(0.5);

    Tensor rhs_term;
    if (grad_Q.defined()) {
      rhs_term = grad_Q + at::matmul(Q, M);
    } else {
      rhs_term = at::matmul(Q, M);
    }

    // We want to compute: (rhs_term @ R^{-H})
    // Note that (rhs_term @ R^{-H}) = (R^{-1} @ rhs_term^H)^H
    // Since R is upper triangular, we can do this using
    // triangular_solve(rhs_term^H, R)^H
    Tensor grad_A;
    std::tie(grad_A, std::ignore) = at::triangular_solve(
        rhs_term.conj().transpose(-2, -1),
        R,
        /*upper=*/true,
        /*transpose=*/false,
        /*unitriangular=*/false);

    return grad_A.conj().transpose(-2, -1);
  };

  auto m = self.size(-2);
  auto n = self.size(-1);

  // Equivalent to: reduced || m <= n
  TORCH_CHECK(
      ((m <= n && (!reduced)) || reduced),
      "The derivative of qr is not implemented when mode='complete' and nrows > ncols.");

  auto grad_Q = grads[0];
  auto grad_R = grads[1];

  if (m >= n) {
    return square_deep_case_backward(grad_Q, grad_R, self, q, r);
  } else {
    // For wide (m < n) input matrices A, partition A = [X|Y] and R = [U|V]
    // X and U are square full rank matrices. We will partition grads,
    // grad_R = [grad_U | grad_V] and grad_A = [grad_X | grad_Y].
    // To obtain grad_X we reuse the gradient formula from the square case.
    // Formulae: grad_X = square_case_grad(grad_Q_prime, grad_U, Q, U),
    //     where grad_Q_prime = grad_Q + Y @ grad_V^H
    // and grad_Y = Q @ grad_V.
    // Then concatenate grads to get grad_A = [grad_X | grad_Y].

    auto Y = self.narrow(-1, m, n - m);
    auto U = r.narrow(-1, 0, m);
    Tensor grad_Y, grad_X, grad_V, grad_Q_prime;

    if (grad_R.defined()) {
      grad_V = grad_R.narrow(-1, m, n - m);
      // reuse grad_R to store grad_U
      grad_R = grad_R.narrow(-1, 0, m);
      // grad_Q_prime starts with the value of Y @ grad_V^H
      grad_Q_prime = at::matmul(Y, grad_V.conj().transpose(-2, -1));
    } else {
      // when grad_R is not defined then grad_V and grad_Q_prime
      // get initialized with zeros
      grad_V = at::zeros_like(Y, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
      grad_Q_prime = at::zeros_like(q, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    }

    if (grad_Q.defined()) {
      // add the grad_Q term into grad_Q_prime when defined o/w is 0
      grad_Q_prime = grad_Q_prime + grad_Q;
    }
    // Calculate grad_X using the helper. Grad_R contains the grad_U value
    grad_X = square_deep_case_backward(grad_Q_prime, grad_R, self, q, U);
    grad_Y = at::matmul(q, grad_V);
    // Concatenate grad_X and grad_Y to get grad_A.
    return at::cat({grad_X, grad_Y}, -1);
  }
}
| |
// Invertible case is derived from Jacobi's formula, and also can be found at:
// http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf
//
// Backward of det. Matrices whose det is non-zero use Jacobi's formula,
// grad * det * A^{-T}; matrices whose det compares equal to 0 go through the
// SVD backward instead. A batched input is partitioned with at::where into
// det != 0 and det == 0 entries, and each part uses its own formula.
Tensor det_backward(const Tensor & grad, const Tensor& self, const Tensor& det) {
  // Singular case: gradient w.r.t. the singular values via prod_backward,
  // mapped back to the input by svd_backward with only the sigma gradient set.
  // NOTE(review): det(A) = det(U) * prod(sigma) * det(V^H), but the factor
  // det(U) * det(V^H) (+-1 for real input) is not applied here -- confirm
  // whether the singular-case gradient can come out with the wrong sign.
  auto singular_case_backward = [&](const Tensor& grad, const Tensor& self, const Tensor& det) -> Tensor {
    Tensor u, sigma, v;
    std::tie(u, sigma, v) = self.svd();
    auto gsigma = prod_backward(grad.unsqueeze(-1), sigma, det.unsqueeze(-1));
    return svd_backward({{}, gsigma, {}}, self, true, true, u, sigma, v);
  };

  // Non-singular case (Jacobi's formula): grad * det * A^{-T}, with grad * det
  // broadcast over the last two (matrix) dimensions.
  auto nonsingular_case_backward = [&](const Tensor& grad, const Tensor& self, const Tensor& det) -> Tensor {
    return unsqueeze_multiple(grad * det, {-1, -2}, self.dim()) * self.inverse().transpose(-2, -1);
  };

  if (self.dim() == 2) {
    // single matrix: dispatch on an exact comparison of det with 0
    if (det.item<double>() == 0) {
      return singular_case_backward(grad, self, det);
    } else {
      return nonsingular_case_backward(grad, self, det);
    }
  } else {
    // at::where(det) selects the batch entries with a non-zero det
    auto nonzero_det_indices = at::native::toListOfOptionalTensors(at::where(det));
    c10::optional<Tensor> first_nonzero_det_index = nonzero_det_indices[0];

    if (first_nonzero_det_index->size(0) == det.numel()) {  // all determinants are nonzero (non-singular)
      return nonsingular_case_backward(grad, self, det);
    }

    auto zero_det_indices = at::native::toListOfOptionalTensors(at::where(det == 0));
    c10::optional<Tensor> first_zero_det_index = zero_det_indices[0];

    if (first_zero_det_index->size(0) == det.numel()) {  // all determinants are zero (singular)
      return singular_case_backward(grad, self, det);
    }

    // Mixed batch: the two index sets (det != 0 and det == 0) together cover
    // the batch, so the uninitialized grad_det is fully written below.
    Tensor grad_det = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);

    // invertible case
    grad_det.index_put_(/*indices=*/nonzero_det_indices,
                        /*value=*/nonsingular_case_backward(grad.index(nonzero_det_indices),
                                                            self.index(nonzero_det_indices),
                                                            det.index(nonzero_det_indices)));

    // non-invertible case, uses SVD
    grad_det.index_put_(/*indices=*/zero_det_indices,
                        /*value=*/singular_case_backward(grad.index(zero_det_indices),
                                                         self.index(zero_det_indices),
                                                         det.index(zero_det_indices)));

    return grad_det;
  }
}
| |
| Tensor logdet_backward(const Tensor & grad, const Tensor& self, const Tensor& logdet) { |
| auto singular_case_backward = [&](const Tensor& grad, const Tensor& self) -> Tensor { |
| Tensor u, sigma, v; |
| std::tie(u, sigma, v) = self.svd(); |
| // logdet = \sum log(sigma) |
| auto gsigma = grad.unsqueeze(-1).div(sigma); |
| return svd_backward({{}, gsigma, {}}, self, true, true, u, sigma, v); |
| }; |
| |
| auto nonsingular_case_backward = [&](const Tensor& grad, const Tensor& self) -> Tensor { |
| return unsqueeze_multiple(grad, {-1, -2}, self.dim()) * self.inverse().transpose(-2, -1); |
| }; |
| |
| if (self.dim() == 2) { |
| if (logdet.item<double>() != -INFINITY) { |
| return nonsingular_case_backward(grad, self); |
| } else { |
| return singular_case_backward(grad, self); |
| } |
| } else { |
| auto finite_logdet_indices = at::native::toListOfOptionalTensors(at::where(logdet != -INFINITY)); |
| c10::optional<Tensor> first_finite_logdet_index = finite_logdet_indices[0]; |
| |
| if (first_finite_logdet_index->size(0) == logdet.numel()) { // all log determinants are finite (non-singular) |
| return nonsingular_case_backward(grad, self); |
| } |
| |
| auto neginf_logdet_indices = at::native::toListOfOptionalTensors(at::where(logdet == -INFINITY)); |
| c10::optional<Tensor> first_neginf_logdet_index = neginf_logdet_indices[0]; |
| |
| if (first_neginf_logdet_index->size(0) == logdet.numel()) { // all log determinants are -inf (singular) |
| return singular_case_backward(grad, self); |
| } |
| |
| Tensor grad_logdet = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); |
| |
| // invertible case |
| grad_logdet.index_put_(/*indices=*/finite_logdet_indices, |
| /*value=*/nonsingular_case_backward(grad.index(finite_logdet_indices), |
| self.index(finite_logdet_indices))); |
| |
| // non-invertible case, uses SVD |
| grad_logdet.index_put_(/*indices=*/neginf_logdet_indices, |
| /*value=*/singular_case_backward(grad.index(neginf_logdet_indices), |
| self.index(neginf_logdet_indices))); |
| |
| return grad_logdet; |
| } |
| } |
| |
| Tensor slogdet_backward(const Tensor& grad_logabsdet, |
| const Tensor& self, |
| const Tensor& signdet, const Tensor& logabsdet) { |
| auto singular_case_backward = [&](const Tensor& grad_logabsdet, const Tensor& self) -> Tensor { |
| Tensor u, sigma, v; |
| // TODO: replace self.svd with linalg_svd |
| std::tie(u, sigma, v) = self.svd(); |
| // sigma has all non-negative entries (also with at least one zero entry) |
| // so logabsdet = \sum log(abs(sigma)) |
| // but det = 0, so backward logabsdet = \sum log(sigma) |
| auto gsigma = grad_logabsdet.unsqueeze(-1).div(sigma); |
| return svd_backward({{}, gsigma, {}}, self, true, true, u, sigma, v); |
| }; |
| |
| auto nonsingular_case_backward = [&](const Tensor& grad_logabsdet, const Tensor& self) -> Tensor { |
| // TODO: replace self.inverse with linalg_inverse |
| return unsqueeze_multiple(grad_logabsdet, {-1, -2}, self.dim()) * self.inverse().conj().transpose(-2, -1); |
| }; |
| |
| if (self.dim() == 2) { |
| bool is_singular = self.is_complex() ? signdet.abs().item<double>() == 0 : signdet.item<double>() == 0; |
| if (is_singular) { |
| return singular_case_backward(grad_logabsdet, self); |
| } else { |
| return nonsingular_case_backward(grad_logabsdet, self); |
| } |
| } else { |
| auto nonzero_signdet_indices = at::native::toListOfOptionalTensors(self.is_complex() ? at::where(signdet.abs()) : at::where(signdet)); |
| c10::optional<Tensor> first_nonzero_signdet_index = nonzero_signdet_indices[0]; |
| |
| if (first_nonzero_signdet_index->size(0) == logabsdet.numel()) { // all log determinants are finite (non-singular) |
| return nonsingular_case_backward(grad_logabsdet, self); |
| } |
| |
| auto zero_signdet_indices = at::native::toListOfOptionalTensors(at::where(signdet == 0)); |
| c10::optional<Tensor> first_zero_signdet_index = zero_signdet_indices[0]; |
| |
| if (first_zero_signdet_index->size(0) == logabsdet.numel()) { // all log determinants are -inf (singular) |
| return singular_case_backward(grad_logabsdet, self); |
| } |
| |
| Tensor grad_slogdet = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); |
| |
| // invertible case |
| grad_slogdet.index_put_(/*indices=*/nonzero_signdet_indices, |
| /*value=*/nonsingular_case_backward(grad_logabsdet.index(nonzero_signdet_indices), |
| self.index(nonzero_signdet_indices))); |
| |
| // non-invertible case, uses SVD |
| grad_slogdet.index_put_(/*indices=*/zero_signdet_indices, |
| /*value=*/singular_case_backward(grad_logabsdet.index(zero_signdet_indices), |
| self.index(zero_signdet_indices))); |
| |
| return grad_slogdet; |
| } |
| } |
| |
| // Reference: |
| // https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf |
| // Sec. 2.3.1 Matrix inverse product |
| std::tuple<Tensor, Tensor> triangular_solve_backward( |
| const Tensor & grad_x, const Tensor & grad_m, |
| const Tensor & b, const Tensor & a, const Tensor & x, |
| const bool upper, const bool transpose, const bool unitriangular, |
| std::array<bool, 2> output_mask) { |
| Tensor grad_b, grad_a; |
| if (grad_x.defined() || grad_m.defined()) { |
| if (grad_x.defined()) { |
| grad_b = std::get<0>(grad_x.triangular_solve(a.conj(), upper, !transpose, unitriangular)); |
| if (output_mask[1]) { |
| grad_a = transpose ? -x.conj().matmul(grad_b.transpose(-1, -2)) : -grad_b.matmul(x.transpose(-1, -2).conj()); |
| if (upper) { |
| grad_a = grad_a.triu((int) unitriangular); |
| } else { |
| grad_a = grad_a.tril(-((int) unitriangular)); |
| } |
| } |
| } |
| if (!grad_a.defined()) { |
| grad_a = at::zeros({1}, a.options()).expand_as(a); |
| } |
| if (!grad_b.defined()) { |
| grad_b = at::zeros({1}, b.options()).expand_as(b); |
| } |
| if (output_mask[1] && grad_m.defined()) { |
| grad_a = grad_a.add(grad_m); |
| } |
| } |
| return std::tuple<Tensor, Tensor>{grad_b, grad_a}; |
| } |
| |
| std::tuple<Tensor, Tensor> cholesky_solve_backward( |
| const Tensor& grad_x, const Tensor& self, |
| const Tensor& input2, const Tensor& result, const bool upper) { |
| Tensor grad_self, grad_input2; |
| if (grad_x.defined()) { |
| grad_self = grad_x.cholesky_solve(input2, /*upper=*/upper); |
| |
| Tensor common_term = at::matmul(grad_self, result.conj().transpose(-2, -1)); |
| common_term = common_term + common_term.conj().transpose(-2, -1); |
| |
| if (upper) { |
| grad_input2 = -at::matmul(input2, common_term); |
| } else { |
| grad_input2 = -at::matmul(common_term, input2); |
| } |
| } |
| return std::tuple<Tensor, Tensor>{grad_self, grad_input2}; |
| } |
| |
| Tensor fft_c2r_backward(const Tensor& grad, IntArrayRef dim, int64_t normalization) { |
| // Forward is C2R (onesided) |
| // Think of onesided C2R irfft as |
| // 1. fill the other half by conjugate symmetry |
| // 2. inverse C2C ifft |
| // 3. discard the complex dimension |
| // So backward is |
| // 1. R2C rfft (essentially add dummy complex dimension, and dft) |
| // 2. accumulate gradient by conjugate symmetry |
| // since rfft results follow conjugate symmetry, we only need to |
| // double some entries from onesided rfft results, i.e., the ones with |
| // their reflected indices also landing out of the onesided range. So |
| // consider the index of last dim: |
| // i. idx = 0. |
| // Reflected to (N - 0) % N = 0. Not doubled. |
| // ii 0 < idx < floor(N/2) (last). |
| // N > N - idx > ceil(N/2) |
| // Reflected to () |
| // iii. idx = floor(N/2) = N/2 (last) when N even. |
| // Reflected to (N - N/2) % N = N/2. Not doubled. |
| // iv. idx = floor(N/2) = (N-1)/2 (last) when N odd. |
| // Reflected to (N - (N-1)/2) % N = (N+1)/2. Doubled. |
| // Therefore, needs to double |
| // idx = 1, 2, ..., N/2 - 1 when N even |
| // idx = 1, 2, ..., (N-1)/2 when N odd |
| // that is |
| // idx = 1, 2, ..., N - (floor(N/2) + 1) |
| // = 1, 2, ..., N - onesided_length |
| auto gI = at::_fft_r2c(grad, dim, normalization, /*onesided=*/true); |
| |
| auto double_length = grad.size(dim.back()) - gI.size(dim.back()); |
| if (double_length > 0) { // also covers case when signal size is zero |
| gI.narrow(dim.back(), 1, double_length).mul_(2); |
| } |
| return gI; |
| } |
| |
| Tensor fft_r2c_backward(const Tensor& grad, IntArrayRef dim, int64_t normalization, |
| bool onesided, int64_t last_dim_size) { |
| if (!onesided) { |
| return at::real(at::_fft_c2c(grad, dim, normalization, /*forward=*/false)); |
| } |
| |
| // Forward is R2C (onesided) |
| // Think of onesided R2C rfft as |
| // 1. view as complex numbers (fill complex dim with zeros) |
| // 2. C2C fft |
| // 3. discard half of results |
| // So backward is |
| // 1. fill the other half with zeros (with `zero_grad_shape` below) |
| // (C2C ifft only take twosided inputs so we need to fill here) |
| // 2. inverse C2C ifft |
| // 3. discard the complex dim |
| auto half_sizes = grad.sizes(); |
| at::DimVector new_grad_shape(half_sizes.begin(), half_sizes.end()); |
| const auto last_dim = at::maybe_wrap_dim(dim.back(), half_sizes.size()); |
| new_grad_shape[last_dim] = last_dim_size; |
| |
| const auto zero_length = last_dim_size - grad.size(dim.back()); |
| auto complex_full_grad = zero_length > 0 ? at::zeros(new_grad_shape, grad.options()) : grad; |
| if (zero_length > 0) { |
| complex_full_grad.slice(last_dim, 0, half_sizes[last_dim]).copy_(grad); |
| } |
| return at::real(at::_fft_c2c(complex_full_grad, dim, normalization, /*forward=*/false)); |
| } |
| |
| // Helper for batchnorm_double_backward |
| Tensor sum_exclude_dim1(const Tensor& to_sum, bool keepdim=true) { |
| auto r = to_sum.sum(0, keepdim); |
| int64_t start_point_exclusive = keepdim ? 1 : 0; |
| for (int64_t dim = r.dim() - 1; dim > start_point_exclusive; dim--) { |
| r = r.sum(dim, keepdim); |
| } |
| return r; |
| } |
| |
| // Helper for batchnorm_double_backward |
| // similar to expand_as below, but doesn't do the expand_as; operates as if |
| // reductions were done with keepdim=True |
| Tensor unsqueeze_dim1(const Tensor& src, const Tensor& target) { |
| auto src_expanded = src; |
| while (src_expanded.sizes().size() < target.sizes().size() - 1) { |
| src_expanded = src_expanded.unsqueeze(1); |
| } |
| if (src_expanded.sizes().size() == target.sizes().size() - 1) { |
| src_expanded = src_expanded.unsqueeze(0); |
| } |
| return src_expanded; |
| } |
| |
| // Helper for batchnorm_double_backward |
| // because gamma/ggG/ggB are 1-dimensional and represent dim==1, we can't |
| // do a straight expansion because it won't follow the broadcasting rules. |
| Tensor expand_as_dim1(const Tensor& src, const Tensor& target) { |
| auto src_expanded = src; |
| while (src_expanded.sizes().size() < target.sizes().size() - 1) { |
| src_expanded = src_expanded.unsqueeze(1); |
| } |
| return src_expanded.expand_as(target); |
| } |
| |
| std::tuple<Tensor, Tensor, Tensor> batchnorm_double_backward( |
| const Tensor & input, |
| const c10::optional<Tensor> & gamma, |
| const Tensor & ggI, |
| const Tensor & ggG, |
| const Tensor & ggB, |
| const Tensor & gO, |
| const c10::optional<Tensor> & running_mean, |
| const c10::optional<Tensor> & running_var, |
| bool training, |
| double eps, |
| const c10::optional<Tensor> & save_mean, |
| const c10::optional<Tensor> & save_invstd, |
| std::array<bool,3> output_mask) { |
| |
| bool affine = isDefined(gamma); |
| // TODO: Do we have a ScalarOrTensor type? Would such a thing exist? |
| Tensor gamma_expanded; |
| Tensor ggG_expanded, ggB_expanded; |
| if (affine) { |
| gamma_expanded = expand_as_dim1(*gamma, input); |
| if (ggG.defined()) { |
| ggG_expanded = expand_as_dim1(ggG, input); |
| } |
| if (ggB.defined()) { |
| ggB_expanded = expand_as_dim1(ggB, input); |
| } |
| } else { |
| gamma_expanded = at::ones({}, input.options()); |
| } |
| |
| // define some terms we will reuse |
| auto M = input.size(0); |
| for (auto s : input.sizes().slice(2)) { |
| M *= s; |
| } |
| // for half inputs, save_mean, save_invstd are float (ideally, we would cast |
| // everything else, but not now) |
| auto mu = unsqueeze_dim1(training ? toLegacyTensor(save_mean).to(input.scalar_type()) : toLegacyTensor(running_mean), input); |
| auto input_sub_mu = input - mu; |
| auto sigma2_eps_neg_1_2 = unsqueeze_dim1( |
| training ? toLegacyTensor(save_invstd).to(input.scalar_type()) |
| : toLegacyTensor(running_var).add(Scalar(eps)).pow(-0.5), |
| input); |
| auto sigma2_eps_neg_1 = sigma2_eps_neg_1_2.pow(2); |
| auto sigma2_eps_neg_3_2 = sigma2_eps_neg_1_2.pow(3); |
| |
| // calculate gI |
| auto input_mu_sigma2_neg_3_2 = input_sub_mu * sigma2_eps_neg_3_2; |
| auto gOinmu_sum = sum_exclude_dim1(gO * input_sub_mu); |
| auto gO_sum = sum_exclude_dim1(gO); |
| |
| Tensor gI; |
| if (ggI.defined() && training) { |
| auto ggI_sum = sum_exclude_dim1(ggI); |
| auto ggIinmu_sum = sum_exclude_dim1(ggI * input_sub_mu); |
| auto all_sub = ((ggI_sum * gO_sum).div_(M)).sub_(sum_exclude_dim1(gO * ggI)).add_( |
| (sigma2_eps_neg_1 * gOinmu_sum * ggIinmu_sum).mul_(3. / M)); |
| auto gI_0t = (input_mu_sigma2_neg_3_2 * all_sub).div_(M); |
| auto gI_1t = (ggIinmu_sum * sigma2_eps_neg_3_2).div_(M) * (gO_sum.div(M) - gO); |
| auto gI_2t = (gOinmu_sum * sigma2_eps_neg_3_2).div_(M) * (ggI_sum.div(M) - ggI); |
| gI = gamma_expanded * (gI_0t.add_(gI_1t).add_(gI_2t)); |
| } |
| |
| // add contribution of gamma term to gI |
| Tensor gI_G_term; |
| if (affine && ggG.defined()) { |
| if (training) { |
| auto t0 = gO * sigma2_eps_neg_1_2; |
| auto t1 = (sigma2_eps_neg_1_2 * gO_sum).div_(-M); |
| auto t2 = (input_mu_sigma2_neg_3_2 * sum_exclude_dim1(gO * input_sub_mu)).div_(-M); |
| gI_G_term = ggG_expanded * (t0.add_(t1).add_(t2)); |
| gI = gI.defined() ? gI.add_(gI_G_term) : gI_G_term; |
| } else { |
| gI_G_term = ggG_expanded * sigma2_eps_neg_1_2 * gO; |
| gI = gI.defined() ? gI.add_(gI_G_term) : gI_G_term; |
| } |
| } |
| |
| // this is the first backward's grad_input |
| auto first_back_grad_input = [&](const Tensor& gO, const Tensor& gamma) -> Tensor { |
| auto h0 = (gamma * sigma2_eps_neg_1_2).div_(M); |
| auto h1 = (M * gO).sub_(sum_exclude_dim1(gO)).sub_( |
| input_sub_mu.mul(sigma2_eps_neg_1) * sum_exclude_dim1(gO * input_sub_mu)); |
| return h0 * h1; |
| }; |
| |
| // calculate gG |
| Tensor gG; |
| if (affine && ggI.defined()) { |
| if (training) { |
| // gG is just the first backwards with the gamma term removed (then shaped properly) |
| gG = ggI * first_back_grad_input(gO, at::ones({}, sigma2_eps_neg_1_2.options())); |
| gG = sum_exclude_dim1(gG, false); |
| } else { |
| gG = sum_exclude_dim1(ggI * gO * sigma2_eps_neg_1_2, false); |
| } |
| } |
| |
| // calculate ggO |
| Tensor ggO; |
| // contribution of input term |
| if (ggI.defined()) { |
| if (training) { |
| ggO = first_back_grad_input(ggI, gamma_expanded); |
| } else { |
| ggO = ggI * sigma2_eps_neg_1_2 * gamma_expanded; |
| } |
| } |
| if (ggG.defined()) { |
| auto ggO_G_term = ggG_expanded * input_sub_mu * sigma2_eps_neg_1_2; |
| ggO = ggO.defined() ? ggO.add_(ggO_G_term) : ggO_G_term; |
| } |
| if (ggB.defined()) { |
| auto ggO_B_term = ggB_expanded; |
| ggO = ggO.defined() ? ggO.add_(ggO_B_term) : ggO_B_term; |
| } |
| |
| if (output_mask[1] && !gG.defined()) { |
| AT_ASSERTM(affine, "gamma should always be defined when it requires grad"); |
| } |
| |
| return std::tuple<Tensor, Tensor, Tensor>{gI, gG, ggO}; |
| |
| } |
| |
// Backward of native_layer_norm built only from regular tensor ops, so the
// result can itself be differentiated again (as the function name indicates).
//
// dY, dmean and drstd are the incoming gradients for the three forward
// outputs; any of them may be undefined. X is the forward input; mean and
// rstd are its per-row statistics.
// The trailing normalized_shape.size() dims of X are flattened into N columns
// and the leading dims into M rows, so all math runs on an (M, N) matrix with
// (M, 1) statistics.
//
// Returns (dX, dgamma, dbeta). Each is computed only when its
// grad_input_mask flag is set (dgamma and dbeta additionally need dY);
// otherwise it is left undefined.
std::tuple<Tensor, Tensor, Tensor>
infinitely_differentiable_native_layer_norm_backward(
    const Tensor& dY,
    const Tensor& dmean,
    const Tensor& drstd,
    const Tensor& X,
    const Tensor& mean,
    const Tensor& rstd,
    const c10::optional<Tensor>& gamma,
    IntArrayRef normalized_shape,
    double eps,
    std::array<bool, 3> grad_input_mask) {

  const int normalized_ndim = normalized_shape.size();
  const auto input_shape = X.sizes();
  const auto input_ndim = X.dim();
  // First normalized dim: dims [0, axis) become rows, [axis, ndim) columns.
  const int axis = input_ndim - normalized_ndim;
  const int64_t M =
      c10::multiply_integers(input_shape.cbegin(), input_shape.cbegin() + axis);
  const int64_t N =
      c10::multiply_integers(input_shape.cbegin() + axis, input_shape.cend());

  Tensor dX;
  Tensor dgamma;
  Tensor dbeta;

  const Tensor X_tensor = X.reshape({M, N});
  const Tensor mean_tensor = mean.reshape({M, 1});
  const Tensor rstd_tensor = rstd.reshape({M, 1});
  // 1 / (number of elements normalized together per row)
  const double s = 1.0 / static_cast<double>(N);

  Tensor dY_tensor;
  if (dY.defined()) {
    dY_tensor = dY.reshape({M, N});
  }

  if (grad_input_mask[0]) {
    Tensor gamma_tensor;
    if (isDefined(gamma)) {
      gamma_tensor = gamma->reshape({1, N});
    }
    Tensor rstd_cube = rstd_tensor * rstd_tensor * rstd_tensor;
    Tensor var;
    Tensor dvar;
    if (drstd.defined()) {
      // Assuming rstd = (var + eps)^(-1/2): recover var = 1/rstd^2 - eps,
      // clamped at 0 against rounding, and map drstd onto var via
      // d rstd / d var = -0.5 * rstd^3.
      var = ((rstd_tensor * rstd_tensor).reciprocal_() - eps).clamp_min(0);
      dvar = -0.5 * rstd_cube * drstd.view({M, 1});
    }
    Tensor ds;
    Tensor db;
    if (dY.defined()) {
      // Per-row sums of dY*gamma*X (ds) and dY*gamma (db), shaped (M, 1).
      ds = (isDefined(gamma) ? dY_tensor * X_tensor * gamma_tensor
                             : dY_tensor * X_tensor)
               .sum(1)
               .unsqueeze_(-1);
      db = (isDefined(gamma) ? dY_tensor * gamma_tensor : dY_tensor)
               .sum(1)
               .unsqueeze_(-1);
      // Closed-form gradient through the normalization, written as
      // dX = a * (dY * gamma) + b * X + c with per-row coefficients a, b, c.
      const Tensor& a = rstd_tensor;
      const Tensor b = (db * mean_tensor - ds) * rstd_cube * s;
      const Tensor c = -b * mean_tensor - db * rstd_tensor * s;
      if (isDefined(gamma)) {
        dX = a * dY_tensor * gamma_tensor + b * X_tensor + c;
      } else {
        dX = a * dY_tensor + b * X_tensor + c;
      }
      // Add the contributions flowing in through the mean / rstd outputs.
      // NOTE(review): these are only added when BOTH dmean and drstd are
      // defined. If only one is defined, its contribution is dropped. Confirm
      // callers guarantee both-or-neither.
      // The positional flags passed to var_std_mean_backward are presumably
      // (unbiased=false, keepdim=true, is_std=false); verify against its
      // declaration.
      if (dmean.defined() && drstd.defined()) {
        dX += var_std_mean_backward(
            {dvar, dmean.view({M, 1})},
            X_tensor,
            var,
            mean_tensor,
            {1},
            false,
            true,
            false);
      }
      dX = dX.reshape_as(X);
    } else if (dmean.defined() && drstd.defined()) {
      // No dY: only the statistics outputs contribute to dX.
      dX = var_std_mean_backward(
               {dvar, dmean.view({M, 1})},
               X_tensor,
               var,
               mean_tensor,
               {1},
               false,
               true,
               false)
               .reshape_as(X);
    }
  }

  // dgamma: column sums of dY * normalized X, i.e. dY * (X - mean) * rstd.
  // Shapes come from gamma, so gamma is assumed defined when requested.
  if (grad_input_mask[1] && dY.defined()) {
    dgamma = (dY_tensor * (X_tensor - mean_tensor) * rstd_tensor)
                 .sum(0)
                 .reshape_as(toLegacyTensor(gamma));
  }
  // dbeta: column sums of dY, also shaped like gamma.
  if (grad_input_mask[2] && dY.defined()) {
    dbeta = dY_tensor.sum(0).reshape_as(toLegacyTensor(gamma));
  }

  return std::make_tuple(dX, dgamma, dbeta);
}
| |
// Backward of native_group_norm built only from regular tensor ops, so the
// result can itself be differentiated again (as the function name indicates).
//
// dY, dmean and drstd are the incoming gradients for the three forward
// outputs; any of them may be undefined. X is the forward input, viewed here
// as (N, G, D, HxW) with D = C / G channels per group. mean and rstd are the
// per-(sample, group) statistics, viewed as (N, G, 1, 1).
//
// Returns (dX, dgamma, dbeta). Each is computed only when its
// grad_input_mask flag is set (dgamma and dbeta additionally need dY);
// otherwise it is left undefined.
std::tuple<Tensor, Tensor, Tensor>
infinitely_differentiable_native_group_norm_backward(
    const Tensor& dY,
    const Tensor& dmean,
    const Tensor& drstd,
    const Tensor& X,
    const Tensor& mean,
    const Tensor& rstd,
    const c10::optional<Tensor>& gamma,
    int64_t N,
    int64_t C,
    int64_t HxW,
    int64_t group,
    double eps,
    std::array<bool, 3> grad_input_mask) {
  const int64_t G = group;
  const int64_t D = C / G;
  // 1 / (number of elements normalized together per group)
  const double s = 1.0 / static_cast<double>(D * HxW);
  Tensor dX;
  Tensor dgamma;
  Tensor dbeta;
  const Tensor X_tensor = X.reshape({N, G, D, HxW});
  const Tensor mean_tensor = mean.reshape({N, G, 1, 1});
  const Tensor rstd_tensor = rstd.reshape({N, G, 1, 1});
  Tensor dY_tensor;
  Tensor ds;
  Tensor db;
  if (dY.defined()) {
    dY_tensor = dY.reshape({N, G, D, HxW});
    // Per-channel spatial sums of dY*X (ds) and dY (db), shaped (N, G, D, 1).
    ds = (dY_tensor * X_tensor).sum(3).unsqueeze_(-1);
    db = dY_tensor.sum(3).unsqueeze_(-1);
  }
  if (grad_input_mask[0]) {
    Tensor gamma_tensor;
    if (isDefined(gamma)) {
      gamma_tensor = gamma->reshape({1, G, D, 1});
    }
    // Assuming rstd = (var + eps)^(-1/2): recover var, clamped at 0 against
    // rounding.
    const Tensor var =
        ((rstd_tensor * rstd_tensor).reciprocal_() - eps).clamp_min(0);
    const Tensor rstd_cube = rstd_tensor * rstd_tensor * rstd_tensor;
    Tensor dvar;
    if (drstd.defined()) {
      // d rstd / d var = -0.5 * rstd^3
      dvar = -0.5 * rstd_cube * drstd.view({N, G, 1, 1});
    }
    if (dY.defined()) {
      // dX = a * dY + b * X + c, with a per channel and b, c per group.
      const Tensor a =
          isDefined(gamma) ? rstd_tensor * gamma_tensor : rstd_tensor;
      // Group sums (over D) of ds*gamma and db*gamma, shaped (N, G, 1, 1);
      // b and c are then overwritten with the final coefficients.
      Tensor b = (isDefined(gamma) ? (ds * gamma_tensor).sum(2) : ds.sum(2))
                     .unsqueeze_(-2);
      Tensor c = (isDefined(gamma) ? (db * gamma_tensor).sum(2) : db.sum(2))
                     .unsqueeze_(-2);
      b = (c * mean_tensor - b) * rstd_cube * s;
      // Uses the updated b from the line above.
      c = -b * mean_tensor - c * rstd_tensor * s;
      dX = a * dY_tensor + b * X_tensor + c;
      // Add the contributions flowing in through the mean / rstd outputs.
      // NOTE(review): these are only added when BOTH dmean and drstd are
      // defined. If only one is defined, its contribution is dropped. Confirm
      // callers guarantee both-or-neither.
      if (dmean.defined() && drstd.defined()) {
        dX += var_std_mean_backward(
            {dvar, dmean.view({N, G, 1, 1})},
            X_tensor,
            var,
            mean_tensor,
            {2, 3},
            false,
            true,
            false);
      }
      dX = dX.reshape_as(X);
    } else if (dmean.defined() && drstd.defined()) {
      // No dY: only the statistics outputs contribute to dX.
      dX = var_std_mean_backward(
               {dvar, dmean.view({N, G, 1, 1})},
               X_tensor,
               var,
               mean_tensor,
               {2, 3},
               false,
               true,
               false)
               .reshape_as(X);
    }
  }
  // dgamma: per-channel sum over the batch of dY * (X - mean) * rstd,
  // expressed via ds and db. Shaped like gamma, so gamma is assumed defined
  // when requested.
  if (grad_input_mask[1] && dY.defined()) {
    dgamma = ((ds - db * mean_tensor) * rstd_tensor).sum(0).reshape_as(toLegacyTensor(gamma));
  }
  // dbeta: per-channel sum of dY, also shaped like gamma.
  if (grad_input_mask[2] && dY.defined()) {
    dbeta = db.sum(0).reshape_as(toLegacyTensor(gamma));
  }

  return std::make_tuple(dX, dgamma, dbeta);
}
| |
| std::tuple<Tensor, Tensor, Tensor> _trilinear_backward(const Tensor& grad_out, const Tensor& i1, const Tensor& i2, const Tensor& i3, |
| IntArrayRef expand1, IntArrayRef expand2, IntArrayRef expand3, |
| IntArrayRef sumdim, int64_t unroll_dim, std::array<bool, 3> grad_mask) { |
| Tensor grad_i1, grad_i2, grad_i3; |
| if (grad_out.defined()) { |
| if (grad_mask[0]) |
| grad_i1 = at::_trilinear(grad_out, i2, i3, sumdim, expand2, expand3, expand1); |
| if (grad_mask[1]) |
| grad_i2 = at::_trilinear(i1, grad_out, i3, expand1, sumdim, expand3, expand2); |
| if (grad_mask[2]) |
| grad_i3 = at::_trilinear(i1, i2, grad_out, expand1, expand2, sumdim, expand3); |
| } |
| return std::tuple<Tensor, Tensor, Tensor>(grad_i1, grad_i2, grad_i3); |
| } |
| |
| Tensor log1p_backward(const Tensor& grad, const Tensor& self) { |
| if (self.is_sparse()) { |
| AT_ERROR( |
| "log1p of a sparse tensor is made to be non-differentiable since ", |
| "local gradient of zero is 1 / (0 + 1) = 1 and it makes the tensor dense. ", |
| "Use a different mathematical operation which preserves sparsity of gradients, ", |
| "or report a bug if you think this is an error."); |
| } |
| return grad / (self + 1).conj(); |
| } |
| |
| Tensor sparse_constructor_values_backward(const Tensor& sparse_grad_out, const Tensor& indices) { |
| return _sparse_mask_helper(sparse_grad_out.coalesce(), indices.contiguous()); |
| } |
| |
| // Because the backward of pad(input, pads) is just pad(grad_output, [-p for p in pads]) |
| Tensor constant_pad_nd_backward(const Tensor& grad, IntArrayRef pad) { |
| auto negated_pad = pad.vec(); |
| std::transform(negated_pad.cbegin(), negated_pad.cend(), negated_pad.begin(), std::negate<int64_t>()); |
| return at::constant_pad_nd(grad, negated_pad, 0); |
| } |
| |
| Tensor embedding_dense_double_backward(const Tensor & grad, const Tensor & indices, int64_t padding_idx) { |
| // since first backward takes care of scaling by frequency, |
| // we don't need to worry about it here. |
| auto gg_weight = grad.index_select(0, indices.reshape(-1)); |
| |
| // reshape gradient as per the shape of indices |
| auto size = indices.sizes().vec(); |
| size.push_back(-1); |
| |
| if (padding_idx >= 0) { |
| gg_weight.masked_fill_((indices == padding_idx).reshape({-1, 1}), 0); |
| } |
| return gg_weight.view(size); |
| } |
| |
| Tensor index_backward(Tensor zeros_like_self, const torch::List<c10::optional<Tensor>>& indices, const Tensor& grad) { |
| return at::_index_put_impl_(zeros_like_self, indices, grad, true, true); |
| } |
| |
| Tensor _cudnn_ctc_loss_backward(const Tensor& grad_out, const Tensor& loss, const Tensor& raw_grad, bool zero_infinity) { |
| if (zero_infinity) { |
| return at::where( |
| loss.unsqueeze(0).unsqueeze(2) == 0, |
| at::zeros({0}, raw_grad.options()), |
| raw_grad * grad_out.unsqueeze(0).unsqueeze(2)); |
| } else { |
| return raw_grad * grad_out.unsqueeze(0).unsqueeze(2); |
| } |
| } |
| |
| bool any_variable_defined(variable_list& variables) { |
| for (auto variable : variables) { |
| if (variable.defined()) { |
| return true; |
| } |
| } |
| return false; |
| } |
| |
| std::tuple<Tensor, Tensor> polar_backward( |
| const Tensor& grad, |
| const Tensor& result) { |
| Tensor grad_abs, grad_angle; |
| if (grad.defined()) { |
| auto grad_conj = grad.conj(); |
| grad_abs = at::real(grad_conj * at::sgn(result)); |
| auto result_mul_1_j = result * Scalar(c10::complex<double>{0.0, 1.0}); |
| grad_angle = at::real(grad_conj * result_mul_1_j); |
| } |
| return std::make_tuple(grad_abs, grad_angle); |
| } |
| |
| } // namespace details |
| } // namespace generated |
| } // namespace autograd |
| } // namespace torch |