| #include "Functions.h" |
| #include <ATen/WrapDimUtils.h> |
| #include <iostream> |
| |
| // define constants like M_PI and C keywords for MSVC |
| #ifdef _MSC_VER |
| #define _USE_MATH_DEFINES |
| #include <ciso646> |
| #endif |
| #include <math.h> |
| #include <algorithm> |
| |
| // ${generated_comment} |
| |
| using at::Tensor; |
| using at::Scalar; |
| using at::IntList; |
| using at::TensorList; |
| |
| namespace torch { namespace autograd { namespace generated { |
| |
| namespace { |
| |
| // Helper functions for autogenerated code |
| |
| // A simple way to imperatively compute index ranges for slots |
| // that have been flattened |
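// For example (illustrative):
//   IndexRangeGenerator gen;
//   gen.range(1);  // returns {0, 1}
//   gen.range(3);  // returns {1, 4}
//   gen.size();    // returns 4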
| struct IndexRangeGenerator { |
| IndexRange range(size_t range_size) { |
| i += range_size; |
| return {i - range_size, i}; |
| } |
| size_t size() { return i; } |
| private: |
| size_t i = 0; |
| }; |
| |
| void copy_range(variable_list& out, IndexRange range, const Tensor & t) { |
| AT_ASSERT(range.second <= out.size(), "range out of bounds"); |
| AT_ASSERT(range.second - range.first == 1, "inconsistent range for Tensor output"); |
| out[range.first] = t; |
| } |
| |
| void copy_range(variable_list& out, IndexRange range, at::ArrayRef<Tensor> t) { |
| AT_ASSERT(range.second <= out.size(), "range out of bounds"); |
| AT_ASSERT(range.second - range.first == t.size(), "inconsistent range for TensorList output"); |
| std::copy(t.begin(), t.end(), out.begin() + range.first); |
| } |
| |
| Tensor not_implemented(const char* name) { |
| throw std::runtime_error( |
| std::string("the derivative for '") + name + "' is not implemented"); |
| } |
| |
| Tensor maybe_multiply(const Tensor & t, const Scalar & s) { |
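// skip the multiply (and the extra allocation) when the scalar is exactly 1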
| bool is_one = false; |
| if (s.isFloatingPoint()) { |
| is_one = s.toDouble() == 1; |
} else if (s.isIntegral()) {
| is_one = s.toLong() == 1; |
| } |
| |
| if (is_one) { |
| return t; |
| } else { |
| return t * s; |
| } |
| } |
| |
| int64_t _safe_size(IntList sizes, int64_t dim) { |
| dim = at::maybe_wrap_dim(dim, sizes.size()); |
| return sizes.size() != 0 ? sizes[dim] : 1; |
| } |
| |
| Tensor norm_backward(const Tensor & grad, const Tensor & self, const Scalar & p_, const Tensor & norm) { |
| double p = p_.toDouble(); |
| Tensor self_scaled; |
| Tensor scale_v; |
| if (p == 0.0) { |
| return zeros_like(self); |
| } else if (p == 1.0) { |
| return self.sign() * grad; |
| } else if (p < 2.0) { |
| self_scaled = self.sign() * self.abs().pow(p - 1); |
| scale_v = grad / norm.pow(p - 1); |
| } else if (p == 2.0) { |
| self_scaled = self; |
| scale_v = grad / norm; |
| } else { |
| self_scaled = self * self.abs().pow(p - 2); |
| scale_v = grad / norm.pow(p - 1); |
| } |
// handle the case where norm is 0: return a subgradient containing 0 there
| scale_v.masked_fill_(norm == 0, 0); |
| return self_scaled * scale_v; |
| } |
| |
| Tensor norm_backward(Tensor grad, const Tensor & self, const Scalar & p_, Tensor norm, int64_t dim, bool keepdim) { |
| if (!keepdim && self.dim() != 0) { |
| grad = grad.unsqueeze(dim); |
| norm = norm.unsqueeze(dim); |
| } |
| return norm_backward(grad, self, p_, norm); |
| } |
| |
| Tensor reduce_to(const Tensor & grad, IntList sizes) { |
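// Reduces a gradient that was broadcast up to a larger shape back down to
// `sizes`: sum away the extra leading dimensions, then sum (with keepdim) over
// every dimension that was expanded from size 1. For example, a grad of shape
// (2, 3, 4) reduced to sizes (3, 1) sums dim 0 and then dim 1.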
| if (sizes.size() == 0) { |
| return grad.sum(); |
| } |
| Tensor result = grad; |
| while (result.dim() > (int64_t)sizes.size()) { |
| result = result.sum(0, false); |
| } |
| for (int64_t i = 0; i < result.dim(); ++i) { |
| if (sizes[i] == 1 && result.sizes()[i] > 1) { |
| result = result.sum(i, true); |
| } |
| } |
| return result; |
| } |
| |
| Tensor permute_backwards(const Tensor & grad, IntList fwd_dims) { |
| // invert the permutation |
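// e.g. fwd_dims = [2, 0, 1] gives dims = [1, 2, 0], so permute(fwd_dims)
// followed by permute(dims) is the identity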
| std::vector<int64_t> dims(fwd_dims.size()); |
| for (size_t i = 0; i < fwd_dims.size(); i++) { |
| dims[fwd_dims[i]] = i; |
| } |
| return grad.permute(dims); |
| } |
| |
| Tensor sum_backward(const Tensor & grad, IntList sizes, int64_t dim, bool keepdim) { |
| if (!keepdim && sizes.size() > 0) { |
| return grad.unsqueeze(dim).expand(sizes); |
| } else { |
| return grad.expand(sizes); |
| } |
| } |
| |
| Tensor reverse_dim(const Tensor& t, int64_t dim) { |
| Tensor index = at::arange(t.type().toScalarType(at::ScalarType::Long), t.size(dim) - 1, -1, -1); |
| return t.index_select(dim, index); |
| } |
| |
| Tensor prod_safe_zeros_backward(const Tensor &grad, const Tensor& inp, int64_t dim) { |
| if (inp.size(dim) == 1) { |
| return grad; |
| } |
| |
| std::vector<int64_t> ones_size(inp.sizes()); |
| ones_size[dim] = 1; |
| Tensor ones = at::ones(grad.type(), ones_size); |
| Tensor exclusive_normal_nocp = at::cat({ones, inp.narrow(dim, 0, inp.size(dim) - 1)}, dim); |
| Tensor exclusive_normal = exclusive_normal_nocp.cumprod(dim); |
| |
| Tensor narrow_reverse = reverse_dim(inp.narrow(dim, 1, inp.size(dim) - 1), dim); |
| Tensor exclusive_reverse_nocp = at::cat({ones, narrow_reverse}, dim); |
| Tensor exclusive_reverse = reverse_dim(exclusive_reverse_nocp.cumprod(dim), dim); |
| |
| return grad * (exclusive_normal * exclusive_reverse); |
| } |
| |
| // note that the gradient for prod is equivalent to: |
| // cumprod(exclusive, normal) * cumprod(exclusive, reverse), e.g.: |
| // input: [ a, b, c] |
| // cumprod(exclusive, normal): [1 , a, a * b] |
| // cumprod(exclusive, reverse): [b * c, c, 1] |
| // product: [b * c, a * c, a * b] |
| // and this is safe under input with 0s. |
| Tensor prod_backward(const Tensor& grad, const Tensor& input, const Tensor& result) { |
| if (input.dim() == 0) { |
| return grad; |
| } |
| Tensor zero_idx = (input == 0).nonzero(); |
| if (zero_idx.numel() == 0) { |
| return (grad * result) / input; |
| } else if (zero_idx.size(0) > 1) { |
| return zeros_like(input); |
| } else { |
| return prod_safe_zeros_backward(grad, input.contiguous().view(-1), 0).view_as(input); |
| } |
| } |
| |
| Tensor prod_backward(Tensor grad, const Tensor& input, Tensor result, int64_t dim, bool keepdim) { |
| if (input.dim() == 0) { |
| return grad; |
| } |
| dim = at::maybe_wrap_dim(dim, input.sizes().size()); |
| if (!keepdim && input.dim() != 1) { |
| grad = grad.unsqueeze(dim); |
| result = result.unsqueeze(dim); |
| } |
| |
| Tensor zero_mask = (input == 0); |
| Tensor slice_zero_count = zero_mask.sum(dim, true); |
| int64_t total_zeros = slice_zero_count.sum().toCLong(); |
| if (total_zeros == 0) { |
| return (grad * result) / input; |
| } else { |
| return prod_safe_zeros_backward(grad, input, dim); |
| } |
| } |
| |
| Tensor sum_scan_exclusive(const Tensor& x, int64_t dim) { |
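// Along `dim`, computes out[k] = sum_{j >= k} x[j] (a reverse cumulative sum),
// e.g. [1, 2, 3] -> [6, 5, 3].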
| Tensor ret = at::cumsum(-x, dim); |
| |
| int64_t end_idx = ret.size(dim) - 1; |
| Tensor ret_sum = ret.narrow(dim, end_idx, 1).clone(); |
| ret -= ret_sum.expand_as(ret); |
| ret += x; |
| return ret; |
| } |
| |
| Tensor cumprod_backward(const Tensor &grad, const Tensor &input, int64_t dim) { |
| /* |
There are two algorithms to do this. The first one
is very efficient, but works only when there are no
zeros in the input.
| |
The second one is much more complex, but it doesn't
assume anything about the input. The main downside is
that it takes time O(n^2), where n = input.size(dim)
| (i.e. the length of the cumulative product). This is in |
| contrast to the forward pass and the efficient algorithm, |
| which are both O(n). |
| |
| The second algorithm is a simple application of the chain |
| rule. If x is an n-dimensional vector, and y = cumprod(x), |
| and F is the final cost, then |
| |
| dF / dx_k = sum_j (dF / dy_j) * (dy_j / dx_k) (1) |
| |
| The term dF / dy_j is just grad_output[j] (assuming again |
| everything is one-dimensional). |
| |
The term (dy_j / dx_k) is easily seen to be
| |
| if j >= k |
| dy_j / dx_k = prod_{1 <= i <= j, i != k} x_i |
| else: |
| dy_j / dx_k = 0 |
| |
| Note that the indicator (j>=k) can be taken out |
| by replacing the sum in (1) with a sum from |
| j = k to n. |
| |
| Thus, |
| df / dx_k = sum_{k <= j <= n} grad_output[j] * (dy_j / dx_k) |
| |
| with |
| dy_j / dx_k = prod_{1 <= i <= j, i != k} x_i (2) |
| |
| Note that this last term is just the cumulative product |
| with k omitted. Thus, if x_k (the input) is nonzero, we can |
| just express this as |
| |
| dy_j / dx_k = (prod_{1 <= i <= j} x_i) / x_k |
| = y_j / x_k |
| |
Therefore,
| |
| df / dx_k = sum_{k <= j <= n} grad_output[j] * y_j / x_k |
| |
| so |
| |
grad_input = sum_scan_exclusive(grad_output * output) / input
| |
If the input contains zeros, we need to calculate dy_j / dx_k
using formula (2), which the code calls omitted_products.
| |
| The way the code calculates it is simply by noting that |
| |
| prod_{1 <= i <= j, i != k} x_i |
| = (prod_{1 <= i <= k} x_i) * (prod_{k + 1 <= i <= j} x_i) |
| |
the first term is calculated as prods_until_k, which, since it
doesn't depend on j, is easy to vectorize.
| |
The second term (indexed by j) is the cumulative product of
x_{k+1}, x_{k+2}, ..., x_n; it's named prods_from_k_plus_1 in the
code and is calculated as a cumprod.
| |
In order to vectorize this properly, note that dy_j / dx_k = 0
whenever j < k, so the sum only needs to run over j >= k; that is
why grad is sliced starting at k right after the assert.
| */ |
| |
| if (input.dim() == 0) { |
| return grad; |
| } |
| dim = at::maybe_wrap_dim(dim, input.sizes().size()); |
| int64_t dim_size = input.size(dim); |
| if (dim_size == 1) { |
| return grad; |
| } |
| |
| // Simple case with nonzero elements in the input |
| if ((input != 0).all().toCByte()) { |
| Tensor result = at::cumprod(input, dim); |
| return sum_scan_exclusive(result * grad, dim) / input; |
| } |
| |
| std::vector<int64_t> ones_size(input.sizes()); |
| ones_size[dim] = 1; |
| Tensor ones = at::ones(grad.type(), {1}).expand(ones_size); |
| Tensor grad_input = at::zeros(grad.type(), input.sizes()); |
| Tensor prods_from_k_plus_1; |
| Tensor omitted_products; |
| for (int k = 0; k < dim_size; ++k) { |
| if (k == 0) { |
| prods_from_k_plus_1 = at::cumprod(input.slice(dim, k + 1), dim); |
| omitted_products = at::cat({ones, prods_from_k_plus_1}, dim); |
| } else if (k == dim_size - 1) { |
| Tensor prods_until_k = at::prod(input.slice(dim, 0, k), dim, true); |
| omitted_products = prods_until_k; |
| } else { |
| Tensor prods_until_k = at::prod(input.slice(dim, 0, k), dim, true); |
| prods_from_k_plus_1 = at::cumprod(input.slice(dim, k+1), dim); |
| omitted_products = prods_until_k.expand_as(prods_from_k_plus_1) * prods_from_k_plus_1; |
| omitted_products = at::cat({prods_until_k, omitted_products}, dim); |
| } |
| |
| // At this point omitted_products is the same size |
| // as input, except on the dimension dim where it's |
| // dim_size - k |
| TORCH_ASSERT(omitted_products.size(dim) == dim_size - k); |
| |
| grad_input.select(dim, k).copy_( |
at::sum(grad.slice(dim, k) * omitted_products, dim));
| } |
| |
| return grad_input; |
| } |
| |
| Tensor cumsum_backward(const Tensor & x, int64_t dim) { |
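// d(cumsum(x))_j / dx_k = 1 for j >= k, so the gradient wrt x is the
// reverse cumulative sum of the incoming gradient along `dim`.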
| if (x.dim() == 0) { |
| return x; |
| } |
| auto ret = at::cumsum(-x, dim); |
| auto ret_sum = ret.narrow(dim, ret.size(dim) - 1, 1).clone(); |
| ret -= ret_sum.expand(ret.sizes()); |
| ret += x; |
| return ret; |
| } |
| |
| Tensor unsqueeze_to(const Tensor & self, IntList sizes) { |
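// Re-inserts a size-1 dimension for every dim of `sizes` that is 1,
// e.g. a (2, 3) tensor with sizes (2, 1, 3) becomes (2, 1, 3).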
| auto result = self; |
| |
| int64_t nDims = sizes.size(); |
| for (int64_t dim = 0; dim < nDims; dim++) { |
| if (sizes[dim] == 1) { |
| result = result.unsqueeze(dim); |
| } |
| } |
| return result; |
| } |
| |
| Tensor unsqueeze_to(const Tensor & self, int64_t dim, IntList sizes) { |
| dim = at::maybe_wrap_dim(dim, sizes.size()); |
// in NumPy it's not an error to unsqueeze a scalar, but we still need to avoid
// unsqueezing it in the backward.
| if (sizes.size() > 0 && sizes[dim] == 1) { |
| return self.unsqueeze(dim); |
| } |
| return self; |
| } |
| |
| std::vector<Tensor> cat_tensors_backward(const Tensor & grad, const std::vector<int64_t> &sizes, int64_t dim) { |
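// Splits `grad` back into per-input gradients along `dim`; sizes[i] is the
// i-th input's extent along that dimension.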
| std::vector<Tensor> grad_inputs(sizes.size()); |
| int64_t accumulate = 0; |
| for (size_t i = 0; i < sizes.size(); ++i) { |
| auto size = sizes[i]; |
| accumulate += size; |
| if (size == 0) { |
| grad_inputs[i] = at::zeros(grad.type(), {0}); |
| } else { |
| grad_inputs[i] = grad.narrow(dim, accumulate - size, size); |
| } |
| } |
| return grad_inputs; |
| } |
| |
| Tensor mm_mat1_backward(const Tensor & grad, const Tensor & mat2, IntList sizes, IntList strides, const Scalar & alpha) { |
// if input was column-major, return grad in column-major order for efficiency
| if (strides[0] == 1 && strides[1] == sizes[0]) { |
| return maybe_multiply(mat2.mm(grad.t()).t(), alpha); |
| } else { |
| return maybe_multiply(grad.mm(mat2.t()), alpha); |
| } |
| } |
| |
| Tensor mm_mat2_backward(const Tensor & grad, const Tensor & mat1, IntList sizes, IntList strides, const Scalar & alpha) { |
// if input was column-major, return grad in column-major order for efficiency
| if (strides[0] == 1 && strides[1] == sizes[0]) { |
| return maybe_multiply(grad.t().mm(mat1).t(), alpha); |
| } else { |
| return maybe_multiply(mat1.t().mm(grad), alpha); |
| } |
| } |
| |
| Tensor renorm_backward(const Tensor & grad, const Tensor & self, Scalar p, int64_t dim, Scalar maxnorm) { |
| auto transposed_sizes = std::vector<int64_t>(self.transpose(dim, 0).sizes()); |
| auto flatten = [&](const Tensor & t) { |
| return t.transpose(dim, 0).contiguous().view({t.size(dim), -1}); |
| }; |
| auto unflatten = [&](const Tensor & t) { |
| return t.contiguous().view(transposed_sizes).transpose(dim, 0); |
| }; |
| |
| // renorm computes the norm over all dimensions except `dim`, which is why |
| // we need the flatten and unflatten business. TODO: simplify this when we |
| // add support for norm over multiple dimensions. |
| auto self_flat = flatten(self); |
| auto grad_flat = flatten(grad); |
| auto norm_flat = self_flat.norm(p, 1, true); |
| auto grad_output = (self_flat * grad_flat).sum(1, true); |
| auto nb = norm_backward(grad_output, self_flat, p, norm_flat, 1, true); |
| auto invnorm = (norm_flat + 1e-7).reciprocal(); |
| auto grad_norm = unflatten(maxnorm * invnorm * (grad_flat - invnorm * nb)); |
| auto norm = unflatten(norm_flat.expand_as(self_flat)); |
| |
| // TODO: remove the detach once comparison ops no longer require grad |
| auto mask = Variable(norm < maxnorm).detach(); |
| return at::where(mask, grad, grad_norm); |
| } |
| |
| Tensor sum_tensorlist(TensorList tl) { |
| if (tl.size() == 0) { |
| throw std::runtime_error("Can't sum tensorlist of size 0"); |
| } |
| Tensor sum = tl[0]; |
| for(size_t i = 1; i < tl.size(); ++i) { |
| sum = sum + tl[i]; |
| } |
| return sum; |
| } |
| |
| Tensor repeat_backward(Tensor grad, int64_t input_dims, IntList repeats) { |
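// Undoes at::repeat: sum over any leading dimensions repeat added, then, for
// each dimension repeated r times, split the gradient into r chunks along that
// dimension and sum them.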
| int64_t num_unsqueezed = grad.dim() - input_dims; |
| for (int64_t i = 0; i < num_unsqueezed; ++i) { |
| grad = grad.sum(0, false); |
| } |
| for (size_t j = num_unsqueezed; j < repeats.size(); ++j) { |
| int64_t repeat = repeats[j]; |
| if (repeat == 1) { |
| continue; |
| } |
| int64_t dim = j - num_unsqueezed; |
| grad = sum_tensorlist(grad.chunk(repeat, dim)); |
| } |
| return grad; |
| } |
| |
| Tensor select_backward_scalar(Tensor grad, const Tensor & input, const Tensor & value) { |
| auto grad_input = zeros_like(input); |
| grad_input.masked_fill_(input == value, grad); |
| return grad_input; |
| } |
| |
| Tensor select_backward(Tensor grad, int64_t dim, Tensor indices, IntList sizes, bool keepdim) { |
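// Scatters `grad` into a zero tensor of the original `sizes` at the positions
// given by `indices` (used by reductions that return (values, indices), such
// as max along a dim).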
| if (!keepdim) { |
| grad = grad.unsqueeze(dim); |
| indices = indices.unsqueeze(dim); |
| } |
| return at::zeros(grad.type(), sizes).scatter_(dim, indices, grad); |
| } |
| |
| Tensor trace_backward(const Tensor & grad, IntList sizes) { |
| if (sizes.size() != 2) { |
| throw std::runtime_error("expected matrix input"); |
| } |
| |
| auto& long_type = grad.type().toScalarType(at::kLong); |
| |
| auto grad_input = at::zeros(grad.type(), sizes[0] * sizes[1]); |
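// in the flattened row-major view, the main diagonal lives at indices
// 0, cols + 1, 2 * (cols + 1), ...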
| auto indices = at::arange(long_type, 0, grad_input.numel(), sizes[1] + 1); |
| grad_input.index_fill_(0, indices, grad); |
| return grad_input.view(sizes); |
| } |
| |
| Tensor unfold_backward(const Tensor & grad, IntList input_sizes, int64_t dim, int64_t size, int64_t step) { |
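// Replay the forward unfold on an index tensor, then scatter-add the gradient
// back to the input positions it came from (an element can appear in several
// windows when step < size).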
| auto& long_type = grad.type().toScalarType(at::kLong); |
| |
| int64_t numel = 1; |
| for (auto size : input_sizes) { |
| numel *= size; |
| } |
| |
| auto idx = at::arange(long_type, 0, numel).view(input_sizes); |
| auto idx_unfolded = idx.unfold(dim, size, step).contiguous().view(-1); |
| auto grad_input = at::zeros(grad.type(), {numel}); |
| grad_input.index_add_(0, idx_unfolded, grad.contiguous().view(-1)); |
| return grad_input.view(input_sizes); |
| } |
| |
| Tensor var_backward(const Tensor & grad, const Tensor & self, bool unbiased) { |
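// var = sum((x - mean)^2) / (numel - unbiased); the terms from differentiating
// through the mean cancel, leaving d var / d x_i = 2 * (x_i - mean) / (numel - unbiased).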
| return (2.0 / (self.numel() - unbiased)) * grad * (self - self.mean()); |
| } |
| |
| Tensor var_backward(Tensor grad, const Tensor & self, int64_t dim, bool unbiased, bool keepdim) { |
| if (self.dim() == 0) { |
| return var_backward(grad, self, unbiased); |
| } |
| if (!keepdim && self.dim() > 1) { |
| grad = grad.unsqueeze(dim); |
| } |
| return (2.0 / (self.size(dim) - unbiased)) * grad * (self - self.mean(dim, true)); |
| } |
| |
| Tensor masked_scatter_backward(const Tensor & grad, const Tensor & mask, IntList sizes) { |
| int64_t numel = 1; |
| for (auto size : sizes) { |
| numel *= size; |
| } |
| auto mask_selected = grad.masked_select(mask); |
| auto diff_nelem = numel - mask_selected.numel(); |
| if (diff_nelem > 0) { |
// masked_select returns a 1-d tensor with one element per nonzero mask entry,
// so fill out the rest with zeros and then reshape back to the original sizes.
| auto zeros_fillin = at::zeros(grad.type(), {diff_nelem}); |
| mask_selected = at::cat({mask_selected, zeros_fillin}, 0); |
| } |
| return mask_selected.view(sizes); |
| } |
| |
| Tensor potrf_backward(Tensor grad, bool upper, Tensor L) { |
| // cf. Iain Murray (2016); arXiv 1602.07527 |
| if (upper) { |
| L = L.t(); |
| grad = grad.t(); |
| } |
| |
| auto phi = [](const Tensor & A) -> Tensor { |
| auto B = A.tril(); |
| B = B - 0.5 * at::diag(at::diag(B)); |
| return B; |
| }; |
| |
// make sure not to double-count the variation, since
// only half of the output matrix is unique
| auto Lbar = grad.tril(); |
| |
| auto P = phi(at::mm(L.t(), Lbar)); |
| Tensor S; |
| std::tie(S, std::ignore) = at::gesv(P + P.t(), L.t()); |
| std::tie(S, std::ignore) = at::gesv(S.t(), L.t()); |
| S = phi(S); |
| if (upper) { |
| S = S.t(); |
| } |
| return S; |
| } |
| |
| Tensor split_with_sizes_backward(const std::vector<torch::autograd::Variable> &grads, |
| IntList split_sizes, int64_t dim, IntList sizes, const Type &type) { |
| dim = at::maybe_wrap_dim(dim, sizes.size()); |
| |
// some of the grads may be undefined (an undefined grad represents a tensor of all zeros).
// Since at::cat can't handle those, materialize them as zero tensors
| std::vector<Tensor> grads_all_defined(grads.size()); |
| for (size_t j = 0; j < grads.size(); ++j) { |
| if (grads[j].defined()) { |
| grads_all_defined[j] = grads[j]; |
| } else { |
| auto length = split_sizes[j]; |
| std::vector<int64_t> grad_size(sizes); |
| grad_size[dim] = length; |
| grads_all_defined[j] = at::zeros(type, grad_size); |
| } |
| } |
| |
| auto ret = at::cat(grads_all_defined, dim); |
| return ret; |
| } |
| |
| Tensor split_backward(const std::vector<torch::autograd::Variable> &grads, |
| int64_t split_size, int64_t dim, IntList sizes, const Type &type) { |
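// All chunks have size split_size except possibly the last one,
// e.g. dim_size = 10, split_size = 3 gives sizes {3, 3, 3, 1}.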
| dim = at::maybe_wrap_dim(dim, sizes.size()); |
| int64_t dim_size = sizes[dim]; |
| int64_t num_splits = grads.size(); |
| std::vector<int64_t> split_sizes(num_splits, split_size); |
| split_sizes[num_splits - 1] = split_size - (split_size * num_splits - dim_size); |
| return split_with_sizes_backward(grads, split_sizes, dim, sizes, type); |
| } |
| |
| Tensor max_pool_double_backward(const Tensor & grad, const Tensor & indices, int dim) { |
| TORCH_ASSERT(indices.dim() >= dim); |
| auto size = std::vector<int64_t>(indices.sizes().slice(0, indices.dim() - dim)); |
| size.push_back(-1); |
| auto indices_view = indices.view(size); |
| return grad.contiguous().view(size).gather(-1, indices_view).view(indices.sizes()); |
| } |
| |
| Tensor glu_double_backward(const Tensor & grad, const Tensor & grad_output, const Tensor & input, int64_t dim) { |
| auto& gO = grad_output; |
| auto input_size = input.size(dim) / 2; |
| auto first_half = input.narrow(dim, 0, input_size); |
| auto second_half = input.narrow(dim, input_size, input_size); |
| auto sig_second_half = second_half.sigmoid(); |
| auto one_sub_sig_second_half = 1 - sig_second_half; |
| auto sig_one_sub_sig = sig_second_half * one_sub_sig_second_half; |
| |
| auto ggI_first_half = grad.narrow(dim, 0, input_size); |
| auto ggI_second_half = grad.narrow(dim, input_size, input_size); |
| auto ggI_second_half_times_first_half = ggI_second_half * first_half; |
| |
| auto gI_first_half = ggI_second_half * gO * sig_one_sub_sig; |
| auto second_order_sh = sig_one_sub_sig * one_sub_sig_second_half - sig_second_half * sig_one_sub_sig; |
| auto gI_second_half = ggI_second_half_times_first_half * gO * second_order_sh + ggI_first_half * gO * sig_one_sub_sig; |
| return at::cat({gI_first_half, gI_second_half}, dim); |
| } |
| |
| Tensor glu_double_backward_grad_output(const Tensor & grad, const Tensor & input, int64_t dim) { |
| if (dim < 0) dim += input.dim(); |
| std::vector<int64_t> sizes = input.sizes(); |
| sizes[dim] /= 2; |
| auto tmp = grad * glu_backward(at::ones(input.type(), sizes), input, dim); |
| return tmp.narrow(dim, 0, sizes[dim]) + tmp.narrow(dim, sizes[dim], sizes[dim]); |
| } |
| |
| Tensor kl_div_double_backward_grad_output(const Tensor & grad, const Tensor & input, const Tensor & target, bool size_average, bool reduce) { |
| auto result = kl_div_backward(grad, input, target, size_average, false); |
| if (reduce && size_average) { |
| return result.mean(); |
| } else if (reduce) { |
| return result.sum(); |
| } |
| return result; |
| } |
| |
| Tensor log_sigmoid_double_backward(const Tensor & grad, const Tensor & input) { |
| auto z = input.sigmoid(); |
| return grad * (z - 1) * z; |
| } |
| |
| Tensor softmax_double_backward(const Tensor & grad, const Tensor & grad_output, int dim, const Tensor & output) { |
| auto gO = grad_output; |
| auto ggI = grad; |
| |
| auto ggI_output = ggI * output; |
| auto ggI_out_sum = ggI_output.sum(dim, true); |
| auto ggI_out_sum_output = ggI_out_sum * output; |
| auto gO_out_sum = (gO * output).sum(dim, true); |
| |
| // gI calculation |
| auto gI_t0 = ggI_output * (gO - gO_out_sum); |
| auto gI_t1 = output * ((ggI_output * gO).sum(dim, true).sub_(gO_out_sum * ggI_out_sum)); |
| auto gI_t2 = ggI_out_sum_output * gO; |
| auto gI_t3 = ggI_out_sum_output * gO_out_sum; |
| return gI_t0 - gI_t1 - gI_t2 + gI_t3; |
| } |
| |
| Tensor log_softmax_double_backward(const Tensor & grad, const Tensor & grad_output, int dim, const Tensor & output) { |
| auto z = output.exp(); |
| return z * grad_output.sum(dim, true) * ((grad * z).sum(dim, true) - grad); |
| } |
| |
| Tensor l1_loss_double_backward_grad_output(const Tensor & grad, const Tensor & input, const Tensor & target, bool size_average, bool reduce) { |
| auto output = l1_loss_backward(grad, input, target, size_average, false); |
if (reduce && size_average) {
| return output.mean(); |
| } else if (reduce) { |
| return output.sum(); |
| } |
| return output; |
| } |
| |
| Tensor smooth_l1_loss_double_backward(const Tensor & grad, const Tensor & input, const Tensor & target, bool size_average, bool reduce) { |
| auto d = (input - target).abs(); |
| auto grad_input = grad * (d < 1).toType(grad.type()); |
| if (size_average && reduce) { |
| grad_input /= input.numel(); |
| } |
| return grad_input; |
| } |
| |
| Tensor smooth_l1_loss_double_backward_grad_output(const Tensor & grad, const Tensor & grad_output, const Tensor & input, const Tensor & target, bool size_average, bool reduce) { |
| if (!reduce) { |
| return smooth_l1_loss_backward(grad, input, target, size_average, reduce); |
| } |
| auto r = smooth_l1_loss_backward(ones_like(grad_output), input, target, size_average, true); |
| return (r * grad).sum(); |
| } |
| |
| static inline int64_t diag_size(int64_t height, int64_t width, int64_t diagonal) { |
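// Number of elements on the given diagonal of a height x width matrix;
// diagonal == 0 is the main diagonal, > 0 is above it, < 0 is below it.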
| if (width > height) { |
| return diag_size(width, height, -diagonal); |
| } |
| // Assumes height >= width |
| auto longest_diag = width; |
| if (diagonal >= 0) { |
| return longest_diag - diagonal; |
| } |
| if (longest_diag < height + diagonal) { |
| return longest_diag; |
| } |
| return height + diagonal; |
| } |
| |
| Tensor diag_backward(const Tensor & grad, IntList input_sizes, int64_t diagonal) { |
| auto ndimension = input_sizes.size(); |
| TORCH_ASSERT(ndimension == 1 || ndimension == 2); |
| |
| if (ndimension == 1 || input_sizes[0] == input_sizes[1]) { |
| return grad.diag(diagonal); |
| } |
| |
| // Input was a matrix but was not square |
| auto grad_input = at::zeros(grad.type(), input_sizes); |
| auto diagonal_size = diag_size(input_sizes[0], input_sizes[1], diagonal); |
| auto storage_offset = diagonal >= 0 ? diagonal : -diagonal * input_sizes[1]; |
| auto diag = grad_input.as_strided({diagonal_size}, {input_sizes[1] + 1}, storage_offset); |
| diag.copy_(grad); |
| return grad_input; |
| } |
| |
| Tensor mse_loss_double_backward(const Tensor & grad, const Tensor & input, bool size_average, bool reduce) { |
| auto grad_input = 2 * grad; |
| if (size_average && reduce) { |
| grad_input /= input.numel(); |
| } |
| return grad_input; |
| } |
| |
| Tensor mse_loss_double_backward_grad_output(const Tensor & grad, const Tensor & grad_output, const Tensor & input, const Tensor & target, bool size_average, bool reduce) { |
| if (!reduce) { |
| return mse_loss_backward(grad, input, target, size_average, reduce); |
| } |
| auto r = mse_loss_backward(ones_like(grad_output), input, target, size_average, true); |
| return (r * grad).sum(); |
| } |
| |
| Tensor soft_margin_loss_double_backward(const Tensor & grad, const Tensor & input, const Tensor & target, bool size_average, bool reduce) { |
| auto z = (input * -target).exp(); |
| auto zplus1 = z + 1; |
| auto grad_input = grad * (target * target) * z / (zplus1 * zplus1); |
| if (size_average && reduce) { |
| grad_input /= input.numel(); |
| } |
| return grad_input; |
| } |
| |
| Tensor soft_margin_loss_double_backward_grad_output(const Tensor & grad, const Tensor & grad_output, const Tensor & input, const Tensor & target, bool size_average, bool reduce) { |
| if (!reduce) { |
| return soft_margin_loss_backward(grad, input, target, size_average, reduce); |
| } |
| auto r = soft_margin_loss_backward(ones_like(grad_output), input, target, size_average, true); |
| return (r * grad).sum(); |
| } |
| |
| Tensor softplus_double_backward(const Tensor & grad, const Tensor & input, Scalar beta, Scalar threshold) { |
| auto x = (input * beta); |
| return _sigmoid_backward(grad, x.sigmoid()) * (x < threshold).toType(grad.type()) * beta; |
| } |
| |
| Tensor as_strided_backward(const Tensor & grad, TensorGeometry base, IntList sizes, IntList strides, int64_t storage_offset) { |
| auto src = base.zeros_with_stride(grad.type()); |
| src.as_strided(sizes, strides, storage_offset - base.storage_offset()).copy_(grad); |
| return src; |
| } |
| |
| std::tuple<Tensor, Tensor> atan2_backward(const Tensor& grad, const Tensor& self, const Tensor& other, std::array<bool, 2> output_mask) { |
| auto recip = (self * self + other * other).reciprocal(); |
| return std::tuple<Tensor,Tensor>{ |
| output_mask[0] ? grad * other * recip : Tensor(), |
| output_mask[1] ? grad * -self * recip : Tensor() }; |
| } |
| |
| // TODO: Seriously consider writing the derivative formulas for |
| // each output separately; there is not all that much sharing |
| // of computation going on here. |
| std::tuple<Tensor, Tensor, Tensor> prelu_double_backward( |
| const Tensor & mb_ggI, |
| const Tensor & mb_ggW, |
| const Tensor & mb_gO, |
| const Tensor & input, |
| const Tensor & weight, |
| std::array<bool, 3> output_mask) { |
| |
| // Zero-fill undefined grads (TODO: do this more efficiently) |
| auto ggI = mb_ggI.defined() ? mb_ggI : input.type().zeros_like(input); |
| auto ggW = mb_ggW.defined() ? mb_ggW : weight.type().zeros_like(weight); |
| auto gO = mb_gO.defined() ? mb_gO : input.type().zeros_like(input); |
| |
| auto positive_mask = (input > 0).type_as(ggI); |
| auto nonpositive_mask = (input <= 0).type_as(ggW); |
| |
| // Explanation: Let input be i, weight be w, grad_output be gO. |
| // f(i, w) = i if i > 0 |
| // = wi if i <= 0 |
// df/di * gO = gO      if i > 0      df/dw * gO = 0       if i > 0
//            = gO * w  if i <= 0                = gO * i  if i <= 0
| // The rest is taking derivatives of these wrt i, w, gO and summing/expanding properly. |
| |
| if (weight.numel() == 1) { |
// from PReLU.forward: num_parameters == 0 is used to indicate that a
| // single weight is shared among all input channels. |
| |
| // this is a little tricky because PReLU currently doesn't take a shape so the weight may be |
| // 1-d when the input is a scalar (and there isn't a good Parameter API for that anyway until Variable |
| // and tensor are merged). So, use weight and ggW as 0-dim in this case. |
| bool scalar_input_1d_weight = (positive_mask.dim() == 0 && weight.dim() == 1); |
| auto weight_maybe_squeeze = scalar_input_1d_weight ? weight.squeeze() : weight; |
| auto ggW_maybe_squeeze = scalar_input_1d_weight ? ggW.squeeze() : ggW; |
| |
| auto mask = positive_mask + nonpositive_mask * weight_maybe_squeeze.expand_as(input); |
| auto ggO = ggI * mask + ggW_maybe_squeeze.expand_as(gO) * (nonpositive_mask * input); |
| return std::tuple<Tensor, Tensor, Tensor>( |
| ggO, |
| ggW_maybe_squeeze.expand_as(gO) * gO * nonpositive_mask, |
| (ggI * gO * nonpositive_mask).sum() |
| ); |
| } else { |
| // Expand ggW to match size of ggI; a simple expand doesn't work because |
| // ggW is the size of the input channel (dim==1 unless there is only 1 dimension). For example, |
| // let ggI be size (3,4,5,6,7) and ggW be size (4). Then we unsqueeze ggW to be size (4,1,1,1) |
| // so the expand succeeds. |
| auto dims_to_unsqueeze = std::max<int64_t>(input.dim() - 2, 0); |
| auto ggW_expanded = ggW; |
| for (int64_t i = 0; i < dims_to_unsqueeze; i++) { |
| ggW_expanded = ggW_expanded.unsqueeze(1); |
| } |
| ggW_expanded = ggW_expanded.expand_as(ggI); |
| |
| auto gI = ggW_expanded * gO * nonpositive_mask; |
| |
| auto gW = ggI * gO * nonpositive_mask; |
| if (input.dim() > 1) { |
| gW = gW.sum(0); |
| } |
| while (gW.dim() > 1) { |
| gW = gW.sum(1); |
| } |
| |
| Tensor ggO; |
| if (output_mask[0]) { |
// expand weight to match input, as with ggW/ggI above
| auto weight_expanded = weight; |
| for (int64_t i = 0; i < dims_to_unsqueeze; i++) { |
| weight_expanded = weight_expanded.unsqueeze(1); |
| } |
| weight_expanded = weight_expanded.expand_as(input); |
| |
| auto mask = positive_mask + nonpositive_mask * weight_expanded; |
| ggO = ggI * mask + ggW_expanded * nonpositive_mask * input; |
| } |
| return std::tuple<Tensor,Tensor,Tensor>{ggO, gI, gW}; |
| } |
| } |
| |
| // https://j-towns.github.io/papers/svd-derivative.pdf |
| Tensor svd_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self, |
| bool some, const Tensor& raw_u, const Tensor& sigma, const Tensor& raw_v) { |
| auto m = self.size(0); |
| auto n = self.size(1); |
| auto k = sigma.size(0); |
| |
| Tensor u, v; |
| if (!some) { |
| // ignore the free subspace |
| u = raw_u.narrow(1, 0, k); |
| v = raw_v.narrow(1, 0, k); |
| } else { |
| u = raw_u; |
| v = raw_v; |
| } |
| |
| auto gu = grads[0]; |
| auto gsigma = grads[1]; |
| auto gv = grads[2]; |
| auto im = at::eye(self.type(), m); |
| auto in = at::eye(self.type(), n); |
| auto ut = u.t(); |
| auto vt = v.t(); |
| auto sigma_mat = sigma.diag(); |
| auto sigma_mat_inv = sigma.pow(-1).diag(); |
| auto sigma_expanded_sq = sigma.pow(2).expand_as(sigma_mat); |
| auto F = (sigma_expanded_sq - sigma_expanded_sq.t()).pow(-1); |
| auto& long_type = sigma.type().toScalarType(at::kLong); |
| auto diag_indices = at::arange(long_type, 0, F.numel(), k + 1); |
| F.view({-1}).index_fill_(0, diag_indices, 0); |
| |
| Tensor u_term, sigma_term, v_term; |
| |
| if (gu.defined()) { |
| u_term = u.mm(F.mul(ut.mm(gu) - gu.t().mm(u))).mm(sigma_mat); |
| if (m > k) { |
| u_term = u_term + (im - u.mm(ut)).mm(gu).mm(sigma_mat_inv); |
| } |
| u_term = u_term.mm(vt); |
| } else { |
| u_term = at::zeros(self.type(), {1}).expand_as(self); |
| } |
| |
| if (gsigma.defined()) { |
| sigma_term = u.mm(gsigma.diag()).mm(vt); |
| } else { |
| sigma_term = at::zeros(self.type(), {1}).expand_as(self); |
| } |
| |
| if (gv.defined()) { |
| auto gvt = gv.t(); |
| v_term = sigma_mat.mm(F.mul(vt.mm(gv) - gvt.mm(v))).mm(vt); |
| if (n > k) { |
| v_term = v_term + sigma_mat_inv.mm(gvt.mm(in - v.mm(vt))); |
| } |
| v_term = u.mm(v_term); |
| } else { |
| v_term = at::zeros(self.type(), {1}).expand_as(self); |
| } |
| |
| return u_term + sigma_term + v_term; |
| } |
| |
| // Formula: |
| // d det / d A_ij = \sum_k (\prod_{l neq k} Sigma_l) U_ik V_jk |
| // that is, if det != 0 |
| // d det / d A = U * (Sigma / det) * V^T |
| Tensor _det_with_svd_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self, |
| const Tensor& det, const Tensor& u, const Tensor& sigma, const Tensor& v) { |
| std::vector<torch::autograd::Variable> svd_grads(grads.begin() + 1, grads.end()); |
| auto svd_term = svd_backward(svd_grads, self, true, u, sigma, v); |
| |
| auto det_grad = grads[0]; |
| auto size = self.size(0); |
| auto null_dim = size - sigma.nonzero().size(0); |
| if (null_dim >= 2) { |
// \prod_{l neq k} Sigma_l is zero everywhere
| return svd_term; |
| } |
| if (null_dim == 1) { |
// only the last singular value is 0
// \prod_{l neq k} Sigma_l is zero for every k except the last,
// where it is:
| auto scale = sigma.narrow(0, 0, size - 1).prod(); |
| auto last_u = u.narrow(1, size - 1, 1); |
| auto last_v = v.narrow(1, size - 1, 1); |
| return svd_term + last_u.mm(last_v.transpose(0, 1)).mul_(scale.mul_(det_grad)); |
| } |
| // no zero singular values |
| return svd_term + u.mm(sigma.pow(-1).mul_(det.mul(det_grad)).diag()).mm(v.transpose(0, 1)); |
| } |
| |
| // Reference: |
| // https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf |
| // Sec. 2.3.1 Matrix inverse product |
| std::tuple<Tensor, Tensor> trtrs_backward( |
| const Tensor & grad_x, const Tensor & grad_m, |
| const Tensor & b, const Tensor & a, const Tensor & x, |
| const bool upper, const bool transpose, const bool unitriangular, |
| std::array<bool, 2> output_mask) { |
| Tensor grad_b, grad_a; |
| if (grad_x.defined()) { |
| grad_b = std::get<0>(grad_x.trtrs(a, upper, !transpose, unitriangular)); |
| if (output_mask[1]) { |
| grad_a = transpose ? -x.mm(grad_b.t()) : -grad_b.mm(x.t()); |
| if (upper) { |
| grad_a = grad_a.triu((int) unitriangular); |
| } else { |
| grad_a = grad_a.tril(-((int) unitriangular)); |
| } |
| } |
| } |
| if (!grad_a.defined()) { |
| grad_a = at::zeros(a.type(), {1}).expand_as(a); |
| } |
| if (!grad_b.defined()) { |
| grad_b = at::zeros(b.type(), {1}).expand_as(b); |
| } |
| if (output_mask[1] && grad_m.defined()) { |
| grad_a = grad_a.add(grad_m); |
| } |
| return std::tuple<Tensor, Tensor>{grad_b, grad_a}; |
| } |
| |
| Tensor sum_exclude_dim1(const Tensor& to_sum, bool keepdim=true) { |
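// Sums over every dimension except dim 1 (the channel dimension used by batchnorm).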
| auto r = to_sum.sum(0, keepdim); |
| int64_t start_point_exclusive = keepdim ? 1 : 0; |
| for (int64_t dim = r.dim() - 1; dim > start_point_exclusive; dim--) { |
| r = r.sum(dim, keepdim); |
| } |
| return r; |
| } |
| |
// similar to expand_as_dim1 below, but doesn't do the expand_as; operates as if
| // reductions were done with keepdim=True |
| Tensor unsqueeze_dim1(const Tensor& src, const Tensor& target) { |
| auto src_expanded = src; |
| while (src_expanded.sizes().size() < target.sizes().size() - 1) { |
| src_expanded = src_expanded.unsqueeze(1); |
| } |
| if (src_expanded.sizes().size() == target.sizes().size() - 1) { |
| src_expanded = src_expanded.unsqueeze(0); |
| } |
| return src_expanded; |
| } |
| |
| // Helper for batchnorm_double_backward |
// because gamma/ggG/ggB are 1-dimensional and correspond to dim 1 of the input, we
// can't do a straight expansion, since that wouldn't follow the broadcasting rules.
| Tensor expand_as_dim1(const Tensor& src, const Tensor& target) { |
| auto src_expanded = src; |
| while (src_expanded.sizes().size() < target.sizes().size() - 1) { |
| src_expanded = src_expanded.unsqueeze(1); |
| } |
| return src_expanded.expand_as(target); |
| } |
| |
| // NB: This currently is PURPOSELY outside of the anonymous namespace, because |
| // we are manually calling it from some legacy batchnorm invocation code. Once |
| // that code moves into derivatives.yaml, this can be moved into the anonymous |
| // namespace, BUT NOT BEFORE THEN. |
| std::tuple<Tensor, Tensor, Tensor> batchnorm_double_backward( |
| const Tensor & input, |
| const Tensor & gamma, |
| const Tensor & ggI, |
| const Tensor & ggG, |
| const Tensor & ggB, |
| const Tensor & gO, |
| const Tensor & running_mean_v, |
| const Tensor & running_var_v, |
| bool training, |
| double eps, |
| const Tensor & save_mean_v, |
| const Tensor & save_std_v, |
| std::array<bool,3> output_mask) { |
| |
| // NB: In the original design of BatchNorm, save_mean, save_std, running_mean |
| // and running_var are unconditionally tensor "buffers", and never get wrapped |
| // in variables. However, when ATen happened, we never designed the API |
| // to allow mixed passing of tensors and variables (and this would be very |
| // confusing, because we always write "Tensor" in the signatures no matter if |
| // it's a Variable or a Tensor). So, when a user calls |
| // batchnorm_double_backward from Python (which still thinks that these are |
| // plain tensors), it goes ahead and wraps them in variables to appease |
// the interface that only understands variables. Consequently, we have to
| // unwrap them again. |
| const Tensor& save_mean = static_cast<const Variable&>(save_mean_v).data(); |
| const Tensor& save_std = static_cast<const Variable&>(save_std_v).data(); |
| const Tensor& running_mean = static_cast<const Variable&>(running_mean_v).data(); |
| const Tensor& running_var = static_cast<const Variable&>(running_var_v).data(); |
| |
| bool affine = gamma.defined(); |
| // TODO: Do we have a ScalarOrTensor type? Would such a thing exist? |
| Tensor gamma_expanded; |
| Tensor ggG_expanded, ggB_expanded; |
| if (affine) { |
| gamma_expanded = expand_as_dim1(gamma, input); |
| if (ggG.defined()) { |
| ggG_expanded = expand_as_dim1(ggG, input); |
| } |
| if (ggB.defined()) { |
| ggB_expanded = expand_as_dim1(ggB, input); |
| } |
| } else { |
| gamma_expanded = input.type().tensor({}).fill_(1); |
| } |
| |
| // define some terms we will reuse |
| auto M = input.size(0); |
| for (auto s : input.sizes().slice(2)) { |
| M *= s; |
| } |
| auto mu = unsqueeze_dim1(make_variable(training ? save_mean : running_mean, /*requires_grad=*/false), input); |
| auto input_sub_mu = input - mu; |
| auto sigma2_eps_neg_1_2 = unsqueeze_dim1(make_variable(training ? save_std : running_var.add(Scalar(eps)).pow(-0.5), /*requires_grad=*/false), input); |
| auto sigma2_eps_neg_1 = sigma2_eps_neg_1_2.pow(2); |
| auto sigma2_eps_neg_3_2 = sigma2_eps_neg_1_2.pow(3); |
| |
| // calculate gI |
| auto input_mu_sigma2_neg_3_2 = input_sub_mu * sigma2_eps_neg_3_2; |
| auto gOinmu_sum = sum_exclude_dim1(gO * input_sub_mu); |
| auto gO_sum = sum_exclude_dim1(gO); |
| |
| Tensor gI; |
| if (ggI.defined() && training) { |
| auto ggI_sum = sum_exclude_dim1(ggI); |
| auto ggIinmu_sum = sum_exclude_dim1(ggI * input_sub_mu); |
| auto all_sub = ((ggI_sum * gO_sum).div_(M)).sub_(sum_exclude_dim1(gO * ggI)).add_( |
| (sigma2_eps_neg_1 * gOinmu_sum * ggIinmu_sum).mul_(3. / M)); |
| auto gI_0t = (input_mu_sigma2_neg_3_2 * all_sub).div_(M); |
| auto gI_1t = (ggIinmu_sum * sigma2_eps_neg_3_2).div_(M) * (gO_sum.div(M) - gO); |
| auto gI_2t = (gOinmu_sum * sigma2_eps_neg_3_2).div_(M) * (ggI_sum.div(M) - ggI); |
| gI = gamma_expanded * (gI_0t.add_(gI_1t).add_(gI_2t)); |
| } |
| |
| // add contribution of gamma term to gI |
| Tensor gI_G_term; |
| if (affine && ggG.defined()) { |
| if (training) { |
| auto t0 = gO * sigma2_eps_neg_1_2; |
| auto t1 = (sigma2_eps_neg_1_2 * gO_sum).div_(-M); |
| auto t2 = (input_mu_sigma2_neg_3_2 * sum_exclude_dim1(gO * input_sub_mu)).div_(-M); |
| gI_G_term = ggG_expanded * (t0.add_(t1).add_(t2)); |
| gI = gI.defined() ? gI.add_(gI_G_term) : gI_G_term; |
| } else { |
| gI_G_term = ggG_expanded * sigma2_eps_neg_1_2 * gO; |
| gI = gI.defined() ? gI.add_(gI_G_term) : gI_G_term; |
| } |
| } |
| |
| // this is the first backward's grad_input |
| auto first_back_grad_input = [&](const Tensor& gO, const Tensor& gamma) -> Tensor { |
| auto h0 = (gamma * sigma2_eps_neg_1_2).div_(M); |
| auto h1 = (M * gO).sub_(sum_exclude_dim1(gO)).sub_( |
| input_sub_mu.mul(sigma2_eps_neg_1) * sum_exclude_dim1(gO * input_sub_mu)); |
| return h0 * h1; |
| }; |
| |
| // calculate gG |
| Tensor gG; |
| if (affine && ggI.defined()) { |
| if (training) { |
| // gG is just the first backwards with the gamma term removed (then shaped properly) |
| gG = ggI * first_back_grad_input(gO, sigma2_eps_neg_1_2.type().tensor({}).fill_(1)); |
| gG = sum_exclude_dim1(gG, false); |
| } else { |
| gG = sum_exclude_dim1(ggI * gO * sigma2_eps_neg_1_2, false); |
| } |
| } |
| |
| // calculate ggO |
| Tensor ggO; |
| // contribution of input term |
| if (ggI.defined()) { |
| if (training) { |
| ggO = first_back_grad_input(ggI, gamma_expanded); |
| } else { |
| ggO = ggI * sigma2_eps_neg_1_2 * gamma_expanded; |
| } |
| } |
| if (ggG.defined()) { |
| auto ggO_G_term = ggG_expanded * input_sub_mu * sigma2_eps_neg_1_2; |
| ggO = ggO.defined() ? ggO.add_(ggO_G_term) : ggO_G_term; |
| } |
| if (ggB.defined()) { |
| auto ggO_B_term = ggB_expanded; |
| ggO = ggO.defined() ? ggO.add_(ggO_B_term) : ggO_B_term; |
| } |
| |
| if (output_mask[0] && !ggO.defined()) ggO = at::zeros_like(gO); |
| if (output_mask[1] && !gG.defined()) { |
| AT_ASSERT(affine, "gamma should always be defined when it requires grad"); |
| gG = at::zeros_like(gamma); |
| } |
| if (output_mask[2] && !gI.defined()) gI = at::zeros_like(input); |
| |
| return std::tuple<Tensor, Tensor, Tensor>{gI, gG, ggO}; |
| |
| } |
| |
| } // anonymous namespace |
| |
| ${autograd_function_definitions} |
| |
| }}} // namespace torch::autograd::generated |