blob: 739511779350ca704905f7176d8b5ff937982440 [file] [log] [blame]
# Defines derivative formulas and Python signatures of methods on Variable
#
# Each entry consists of:
# - A 'name', which specifies the ATen name of the function you
# are defining derivatives for, and an argument specification.
# - One or more gradient entries, each mapping a differentiable input
# name to a formula specifying how to compute its gradient.
# Note that a single gradient entry can specify the gradient
# formula for multiple input names by using a comma-separated key
# such as "input1, input2" (see atan2 for an example).
#
# If a function has out-of-place and in-place variants, then the derivative
# definition for the in-place variant is optional. It will default to the
# definition for the out-of-place variant. Similarly, _out variants will
# default to the derivative for the non-_out variant.
#
# Gradient expressions are standard C++ expressions operating on ATen
# variables. In a gradient expression, the following variables are in
# scope:
#
# - 'grad', the gradient of the output (often spelled grad_output
# in Python) which we are going to left-multiply.
#
# When a function returns multiple *differentiable* outputs,
# you can refer to the gradient of each output using 'grads',
# e.g., 'grads[0]', 'grads[1]'.
#
# When a function returns *one* differentiable output (the
# first output) and some more nondifferentiable outputs,
# you MUST refer to the gradient of the differentiable output with
# 'grad' (this case is special-cased in our code generation).
#
# - Any of the input arguments, tensor or non-tensor, including
# argument names that only appear in Declarations.cwrap, e.g. 'output'.
#
# - 'result', representing the result of evaluating the forward
# expression for ATen native function declarations. If the forward
# expression outputs a tuple, use 'resultX' instead to access the
# X-th entry (0-indexed, e.g. 'result0', 'result1').
#
# - 'grad_input_mask', a std::array<bool, n>, specifies which input
# gradients are actually needed. For example, in the entry
# `input0, input1: foo(grad_input_mask)`, `grad_input_mask` is a size-two
# array, where `grad_input_mask[0]` is true if `input0` requires
# grad, and `grad_input_mask[1]` is true if `input1` requires grad.
#
# (NB: if your function computes gradient for a list of tensors,
# the `grad_input_mask` will only have a single entry for the list
# specifying if either zero or at least one tensor from the list requires
# grad. If we want to support more fine-grained signalling,
# we'll need some alternate variable which is not a std::array)
#
# - 'retain_variables', a bool which is true if a user has specified
# that saved variables should be retained in case the backward pass is
# run again later. This allows an optimization where we can
# destroy saved buffers if we know variables are not going to be retained,
# e.g., it is used by _cudnn_rnn.
#
# If you need a complex expression, e.g., with local variables,
# write a _backward function in tools/autograd/templates/Functions.cpp
# and invoke it from here. By the way, go read
# https://github.com/zdevito/ATen/issues/163; this describes an
# important hazard that occurs when porting backwards from Python to C++
#
# Double backwards gradient expressions can be somewhat confusing;
# the most important thing to remember is: (1) you need to define a
# derivative formula for every input, including inputs named things
# like 'grad_output', and (2) the gradient to multiply with is always
# called 'grad' (even though it really is a grad-grad).
#
# NB: There are a number of gradient definitions in here which are bogus
# (implemented using zeros_like). These gradients are (hopefully) not
# used by our frontend. You MUST check the frontend code; search for
# OpName.apply to see if it's still using a legacy Python style API.
#
# NB: The parameter names here MUST be consistent with the parameter names
# in ./torch/lib/ATen/Declarations.cwrap
# ---------------------------------------------------------------------------
# Entries a-b: pointwise math, the add/addc*/addm*-family, batched matrix
# multiplies, and a few LAPACK-style ops with no derivative yet.
# ---------------------------------------------------------------------------
# d/dx |x| = sign(x).
- name: abs(Tensor self)
self: grad * self.sign()
# d/dx acos(x) = -1 / sqrt(1 - x^2), written as -rsqrt(1 - x*x).
- name: acos(Tensor self)
self: grad * -((-self * self + 1).rsqrt())
# A Scalar `other` gets no gradient entry; only `self` is differentiable.
- name: add(Tensor self, Scalar other, *, Scalar alpha)
self: grad
# `other` is scaled by alpha in the forward, so its gradient is grad scaled
# by alpha (maybe_multiply presumably skips the multiply when alpha is 1).
- name: add(Tensor self, Tensor other, *, Scalar alpha)
self: grad
other: maybe_multiply(grad, alpha)
# grad has no batch dimension here: it is unsqueezed and expanded over
# batch1.size(0) before the batched matmuls against the transposed operands.
- name: addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta, Scalar alpha)
self: maybe_multiply(grad, beta)
batch1: grad.unsqueeze(0).expand({ batch1.size(0), batch1.size(1), batch2.size(2) }).bmm(batch2.transpose(1, 2)) * alpha
batch2: batch1.transpose(1, 2).bmm(grad.unsqueeze(0).expand({ batch1.size(0), batch1.size(1), batch2.size(2) })) * alpha
# Formulas correspond to out = self + value * tensor1 / tensor2
# (quotient rule for tensor2).
- name: addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value)
self: grad
tensor1: grad * value / tensor2
tensor2: -grad * value * tensor1 / (tensor2 * tensor2)
# Formulas correspond to out = self + value * tensor1 * tensor2 (product rule).
- name: addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value)
self: grad
tensor1: grad * tensor2 * value
tensor2: grad * tensor1 * value
# The mm_mat*_backward helpers also receive the operand's sizes and strides;
# presumably so the gradient can be laid out to match that operand.
- name: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha)
self: maybe_multiply(grad, beta)
mat1: mm_mat1_backward(grad, mat2, mat1.sizes(), mat1.strides(), alpha)
mat2: mm_mat2_backward(grad, mat1, mat2.sizes(), mat2.strides(), alpha)
# Matrix-vector: mat gets the outer product grad x vec, vec gets mat^T grad.
- name: _addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta, Scalar alpha)
self: maybe_multiply(grad, beta)
mat: grad.ger(vec) * alpha
vec: mat.t().mv(grad) * alpha
# Outer-product update: each vector gets grad (or grad^T) times the other.
- name: _addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta, Scalar alpha)
self: maybe_multiply(grad, beta)
vec1: grad.mv(vec2) * alpha
vec2: grad.t().mv(vec1) * alpha
- name: alias(Tensor self)
self: grad
# Needs self's geometry (TensorGeometry) to map grad back through the
# strided view described by size/stride/storage_offset.
- name: as_strided(Tensor self, IntList size, IntList stride, int64_t storage_offset)
self: as_strided_backward(grad, TensorGeometry(self), size, stride, storage_offset)
# d/dx asin(x) = 1 / sqrt(1 - x^2).
- name: asin(Tensor self)
self: grad * (-self * self + 1).rsqrt()
# d/dx atan(x) = 1 / (1 + x^2).
- name: atan(Tensor self)
self: grad / (self * self + 1)
# One helper produces both gradients; grad_input_mask tells it which are needed.
- name: atan2(Tensor self, Tensor other)
self, other: atan2_backward(grad, self, other, grad_input_mask)
# Per-batch version of addmm: grad already carries the batch dimension.
- name: baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta, Scalar alpha)
self: maybe_multiply(grad, beta)
batch1: grad.bmm(batch2.transpose(1, 2)) * alpha
batch2: batch1.transpose(1, 2).bmm(grad) * alpha
# Random sampling: the gradient w.r.t. self is defined as zeros.
- name: bernoulli(Tensor self, Generator generator)
self: zeros_like(grad)
- name: bmm(Tensor self, Tensor mat2)
self: grad.bmm(mat2.transpose(1, 2))
mat2: self.transpose(1, 2).bmm(grad)
# No derivatives are implemented for the batched LU ops below.
- name: btrifact(Tensor self, bool pivot)
self: not_implemented("btrifact")
- name: btrifact_with_info(Tensor self, bool pivot)
self: not_implemented("btrifact_with_info")
- name: btrisolve(Tensor self, Tensor LU_data, Tensor LU_pivots)
self: not_implemented("btrisolve")
# ---------------------------------------------------------------------------
# Entries c-f: concatenation, clamping, trig, cumulative ops, determinant,
# division, in-place comparisons, erf/exp, expand and fill.
# ---------------------------------------------------------------------------
# A TensorList input has a single entry; the helper is given the per-input
# sizes (via to_arg_sizes) so it can split grad back along `dim`.
- name: cat(TensorList tensors, int64_t dim)
tensors: cat_tensors_backward(grad, to_arg_sizes(tensors, dim), dim)
# Random sampling: the gradient w.r.t. self is defined as zeros.
- name: cauchy_(Tensor self, double median, double sigma, Generator generator)
self: zeros_like(grad)
# Piecewise constant: the derivative is zero almost everywhere.
- name: ceil(Tensor self)
self: zeros_like(grad)
# Gradient flows only where self is strictly inside the bound(s); at exactly
# min or max it is zero. The boolean masks are cast to grad's type first.
- name: clamp(Tensor self, Scalar min, Scalar max)
self: grad * (self > min).type_as(grad) * (self < max).type_as(grad)
- name: clamp_min(Tensor self, Scalar min)
self: grad * (self > min).type_as(grad)
- name: clamp_max(Tensor self, Scalar max)
self: grad * (self < max).type_as(grad)
- name: clone(Tensor self)
self: grad
- name: cos(Tensor self)
self: grad * -self.sin()
- name: cosh(Tensor self)
self: grad * self.sinh()
# For out = self x other: d/dself = other x grad, d/dother = grad x self.
- name: cross(Tensor self, Tensor other, int64_t dim)
self: other.cross(grad, dim)
other: grad.cross(self, dim)
- name: cumprod(Tensor self, int64_t dim)
self: cumprod_backward(grad, self, dim)
- name: cumsum(Tensor self, int64_t dim)
self: cumsum_backward(grad, dim)
# One helper returns the gradients for all three inputs.
- name: conv_tbc(Tensor self, Tensor weight, Tensor bias, int64_t pad)
self, weight, bias: conv_tbc_backward(grad, self, weight, bias, pad)
# Multiple forward outputs: the helper receives every output gradient
# (`grads`) and the forward outputs result0..result3 (0-indexed).
- name: _det_with_svd(Tensor self)
self: _det_with_svd_backward(grads, self, result0, result1, result2, result3)
- name: diag(Tensor self, int64_t diagonal)
self: diag_backward(grad, self.sizes(), diagonal)
# The gradient is the norm gradient of (self - other); `other` gets its negation.
- name: dist(Tensor self, Tensor other, Scalar p)
self: norm_backward(grad, self - other, p, result)
other: -norm_backward(grad, self - other, p, result)
- name: div(Tensor self, Scalar other)
self: grad / other
# Quotient rule: d/dother (self / other) = -self / other^2.
- name: div(Tensor self, Tensor other)
self: grad / other
other: -grad * self / (other * other)
- name: dot(Tensor self, Tensor tensor)
self: grad * tensor
tensor: grad * self
- name: eig(Tensor self, bool eigenvectors)
self: not_implemented("eig")
# In-place comparisons: gradients for every input are defined as zeros.
- name: eq_(Tensor self, Scalar other)
self: zeros_like(self)
- name: eq_(Tensor self, Tensor other)
self: zeros_like(self)
other: zeros_like(other)
# d/dx erf(x) = 2 / sqrt(pi) * exp(-x^2).
- name: erf(Tensor self)
self: 2.0 / sqrt(M_PI) * exp(-(self.pow(2))) * grad
# d/dy erfinv(y) = sqrt(pi) / 2 * exp(erfinv(y)^2); erfinv is recomputed
# from self rather than taken from the forward result.
- name: erfinv(Tensor self)
self: 0.5 * sqrt(M_PI) * exp(self.erfinv().pow(2)) * grad
# d/dx exp(x) = exp(x), so the forward result is reused.
- name: exp(Tensor self)
self: grad * result
# d/dx expm1(x) = exp(x) = result + 1.
- name: expand(Tensor self, IntList size)
self: reduce_to(grad, self.sizes())
- name: exponential_(Tensor self, double lambd, Generator generator)
self: zeros_like(grad)
# fill_ overwrites every element of self, so self's gradient is zero.
# A Tensor `value` feeds every element, so it receives the sum of grad.
- name: fill_(Tensor self, Scalar value)
self: zeros_like(grad)
- name: fill_(Tensor self, Tensor value)
self: zeros_like(grad)
value: grad.sum()
# ---------------------------------------------------------------------------
# Entries f-l: rounding/modulo, gather and index_* ops, LAPACK solvers,
# inverse, and selection-based reductions.
# ---------------------------------------------------------------------------
# Piecewise constant: the derivative is zero almost everywhere.
- name: floor(Tensor self)
self: zeros_like(grad)
# fmod(x, y) = x - trunc(x / y) * y, whose derivative w.r.t. x is 1 almost everywhere.
- name: fmod(Tensor self, Scalar other)
self: grad
- name: fmod(Tensor self, Tensor other)
self: grad
# Quoted because the message contains ": ", which YAML would otherwise
# parse as a nested mapping.
other: 'not_implemented("fmod: other")'
# frac(x) = x - trunc(x), so its derivative is 1 almost everywhere.
- name: frac(Tensor self)
self: grad
# Scatter grad back into zeros of self's shape; scatter_add_ accumulates
# contributions from repeated indices.
- name: gather(Tensor self, int64_t dim, Tensor index)
self: at::zeros(grad.type(), self.sizes()).scatter_add_(dim, index, grad)
# In-place comparisons: gradients for every input are defined as zeros.
- name: ge_(Tensor self, Scalar other)
self: zeros_like(self)
- name: ge_(Tensor self, Tensor other)
self: zeros_like(self)
other: zeros_like(other)
- name: gels(Tensor self, Tensor A)
self: not_implemented("gels")
A: not_implemented("gels")
# Random sampling: the gradient w.r.t. self is defined as zeros.
- name: geometric_(Tensor self, double p, Generator generator)
self: zeros_like(grad)
- name: geqrf(Tensor self)
self: not_implemented("geqrf")
# Outer product self (x) vec2: each vector gets grad (or grad^T) times the other.
- name: ger(Tensor self, Tensor vec2)
self: grad.mv(vec2)
vec2: grad.t().mv(self)
# `solution` is the forward output X with A X = self. Then
# d/dself = A^{-T} grad (solved with gesv) and d/dA = -(A^{-T} grad) X^T.
- name: gesv(Tensor self, Tensor A)
self: std::get<0>(gesv(grad, A.t()))
A: -at::mm(std::get<0>(gesv(grad, A.t())), solution.t())
- name: gt_(Tensor self, Scalar other)
self: zeros_like(self)
- name: gt_(Tensor self, Tensor other)
self: zeros_like(self)
other: zeros_like(other)
- name: histc(Tensor self, int64_t bins, Scalar min, Scalar max)
self: not_implemented("histc")
- name: index_add_(Tensor self, int64_t dim, Tensor index, Tensor source)
self: grad
source: grad.index_select(dim, index)
# Positions of self overwritten by the op get zero gradient (index_fill_ 0);
# the written values (source / Tensor value) receive the selected grad.
- name: index_copy_(Tensor self, int64_t dim, Tensor index, Tensor source)
self: grad.clone().index_fill_(dim, index, 0)
source: grad.index_select(dim, index)
- name: index_fill_(Tensor self, int64_t dim, Tensor index, Scalar value)
self: grad.clone().index_fill_(dim, index, 0)
- name: index_fill_(Tensor self, int64_t dim, Tensor index, Tensor value)
self: grad.clone().index_fill_(dim, index, 0)
value: grad.index_select(dim, index).sum()
# index_add_ into zeros of self's shape accumulates repeated indices.
- name: index_select(Tensor self, int64_t dim, Tensor index)
self: at::zeros(grad.type(), self.sizes()).index_add_(dim, index, grad)
# `output` is the forward result A^{-1}; d/dA = -A^{-T} grad A^{-T}.
- name: inverse(Tensor self)
self: -at::mm(output.t(), at::mm(grad, output.t()))
# Routes grad to the positions recorded in the forward output `indices`.
- name: kthvalue(Tensor self, int64_t k, int64_t dim, bool keepdim)
self: select_backward(grad, dim, indices, self.sizes(), keepdim)
- name: le_(Tensor self, Scalar other)
self: zeros_like(self)
- name: le_(Tensor self, Tensor other)
self: zeros_like(self)
other: zeros_like(other)
# ---------------------------------------------------------------------------
# Entries l-n: lerp, gamma-family functions, logs, masked ops, max/min/mean/
# median/mode reductions, matrix multiplies and products.
# ---------------------------------------------------------------------------
# Formulas correspond to out = self + weight * (end - self).
- name: lerp(Tensor self, Tensor end, Scalar weight)
self: grad * (1 - weight.toDouble())
end: grad * weight
# Derivative chain: lgamma' = digamma, digamma' = polygamma(1),
# polygamma(n)' = polygamma(n + 1).
- name: lgamma(Tensor self)
self: grad * digamma(self)
- name: digamma(Tensor self)
self: grad * polygamma(1, self)
- name: polygamma(int64_t n, Tensor self)
self: grad * polygamma(n + 1, self)
- name: log(Tensor self)
self: grad.div(self)
- name: log1p(Tensor self)
self: grad / (self + 1)
# Random sampling: the gradient w.r.t. self is defined as zeros.
- name: log_normal_(Tensor self, double mean, double std, Generator generator)
self: zeros_like(grad)
# In-place comparisons: gradients for every input are defined as zeros.
- name: lt_(Tensor self, Scalar other)
self: zeros_like(self)
- name: lt_(Tensor self, Tensor other)
self: zeros_like(self)
other: zeros_like(other)
# Masked positions of self are overwritten, so their gradient is zero;
# a Tensor `value` receives the sum of grad over the masked positions.
- name: masked_fill_(Tensor self, Tensor mask, Scalar value)
self: grad.clone().masked_fill_(mask, 0)
- name: masked_fill_(Tensor self, Tensor mask, Tensor value)
self: grad.clone().masked_fill_(mask, 0)
value: at::where(mask, grad, zeros_like(grad)).sum()
- name: masked_scatter_(Tensor self, Tensor mask, Tensor source)
self: grad.clone().masked_fill_(mask, 0)
source: masked_scatter_backward(grad, mask, source.sizes())
# Scatters grad back into zeros of self's shape at the masked positions.
- name: masked_select(Tensor self, Tensor mask)
self: zeros_like(self).masked_scatter_(mask, grad)
# Dim reductions route grad through the forward output indices
# (max_indices / indices / min_indices). The full reductions use
# select_backward_scalar, which presumably locates `result` within self.
- name: max(Tensor self, int64_t dim, bool keepdim)
self: select_backward(grad, dim, max_indices, self.sizes(), keepdim)
- name: max(Tensor self)
self: select_backward_scalar(grad, self, result)
# Elementwise max: self gets grad where self > other; ties go to `other`.
- name: max(Tensor self, Tensor other)
self: grad.clone().masked_fill_(self <= other, 0)
other: grad.clone().masked_fill_(self > other, 0)
# sum_backward divided by _safe_size (presumably the size of `dim`,
# guarded against zero).
- name: mean(Tensor self, int64_t dim, bool keepdim)
self: sum_backward(grad, self.sizes(), dim, keepdim) / _safe_size(self.sizes(), dim)
- name: mean(Tensor self)
self: grad.expand(self.sizes()) / self.numel()
- name: median(Tensor self)
self: select_backward_scalar(grad, self, result)
- name: median(Tensor self, int64_t dim, bool keepdim)
self: select_backward(grad, dim, indices, self.sizes(), keepdim)
- name: min(Tensor self, int64_t dim, bool keepdim)
self: select_backward(grad, dim, min_indices, self.sizes(), keepdim)
- name: min(Tensor self)
self: select_backward_scalar(grad, self, result)
# Elementwise min: self gets grad where self < other; ties go to `other`.
- name: min(Tensor self, Tensor other)
self: grad.clone().masked_fill_(self >= other, 0)
other: grad.clone().masked_fill_(self < other, 0)
# Same helpers as addmm, with alpha fixed to 1.
- name: _mm(Tensor self, Tensor mat2)
self: mm_mat1_backward(grad, mat2, self.sizes(), self.strides(), 1)
mat2: mm_mat2_backward(grad, self, mat2.sizes(), mat2.strides(), 1)
- name: mode(Tensor self, int64_t dim, bool keepdim)
self: select_backward(grad, dim, indices, self.sizes(), keepdim)
- name: mul(Tensor self, Scalar other)
self: grad * other
# Product rule.
- name: mul(Tensor self, Tensor other)
self: grad * other
other: grad * self
- name: mv(Tensor self, Tensor vec)
self: grad.ger(vec)
vec: self.t().mv(grad)
- name: ne_(Tensor self, Scalar other)
self: zeros_like(self)
- name: ne_(Tensor self, Tensor other)
self: zeros_like(self)
other: zeros_like(other)
# ---------------------------------------------------------------------------
# Entries n-p: negation, norms, normal sampling, QR helpers, permute,
# Cholesky-related ops.
# ---------------------------------------------------------------------------
- name: neg(Tensor self)
self: grad.neg()
# Both norm overloads pass the forward result (the norm itself) to the helper.
- name: norm(Tensor self, Scalar p)
self: norm_backward(grad, self, p, result)
- name: norm(Tensor self, Scalar p, int64_t dim, bool keepdim)
self: norm_backward(grad, self, p, result, dim, keepdim)
# Random sampling: gradients w.r.t. self, mean and std are defined as zeros
# shaped like the corresponding input.
- name: normal_(Tensor self, double mean, double std, Generator generator)
self: zeros_like(grad)
- name: normal(Tensor mean, double std, Generator generator)
mean: at::zeros(grad.type(), mean.sizes())
- name: normal(double mean, Tensor std, Generator generator)
std: at::zeros(grad.type(), std.sizes())
- name: normal(Tensor mean, Tensor std, Generator generator)
mean: at::zeros(grad.type(), mean.sizes())
std: at::zeros(grad.type(), std.sizes())
- name: orgqr(Tensor self, Tensor input2)
self: not_implemented("orgqr")
input2: not_implemented("orgqr")
- name: ormqr(Tensor self, Tensor input2, Tensor input3, bool left, bool transpose)
self: not_implemented("ormqr")
input2: not_implemented("ormqr")
input3: not_implemented("ormqr")
# permute_backwards presumably applies the inverse permutation of `dims`.
- name: permute(Tensor self, IntList dims)
self: permute_backwards(grad, dims)
- name: poisson(Tensor self, Generator generator)
self: zeros_like(self)
# `output` is the forward result (the factor), passed to the helper.
- name: potrf(Tensor self, bool upper)
self: potrf_backward(grad, upper, output)
- name: potri(Tensor self, bool upper)
self: not_implemented("potri")
# No derivative is implemented for potrs. Both tensor inputs report "potrs"
# (not the unrelated "potri") so the error names the op actually called.
- name: potrs(Tensor self, Tensor input2, bool upper)
self: not_implemented("potrs")
input2: not_implemented("potrs")
# ---------------------------------------------------------------------------
# Entries p-r: powers, products, put_, sampling, reciprocal, remainder.
# ---------------------------------------------------------------------------
# Power rule: d/dx x^c = c * x^(c-1).
# NOTE(review): when exponent == 0 this evaluates 0 * self.pow(-1), which is
# NaN (0 * inf) wherever self == 0, although the true derivative there is 0.
- name: pow(Tensor self, Scalar exponent)
self: grad * exponent * self.pow(exponent.toDouble() - 1)
# d/dself uses the power rule (same NaN caveat at exponent == 0, self == 0);
# d/dexponent x^y = x^y * ln(x). x^y is recomputed from self here rather
# than taken from the forward result.
- name: pow(Tensor self, Tensor exponent)
self: grad * exponent * self.pow(exponent - 1)
exponent: grad * self.pow(exponent) * self.log()
- name: prod(Tensor self, int64_t dim, bool keepdim)
self: prod_backward(grad, self, result, dim, keepdim)
- name: prod(Tensor self)
self: prod_backward(grad, self, result)
- name: pstrf(Tensor self, bool upper, Scalar tol)
self: not_implemented("pstrf")
# put_ with zeros: when accumulate is false the written positions of self
# get zero gradient; when true, adding zeros leaves grad unchanged.
- name: put_(Tensor self, Tensor index, Tensor source, bool accumulate)
self: grad.clone().put_(index, zeros_like(source), accumulate)
source: grad.take(index)
- name: qr(Tensor self)
self: not_implemented("qr")
# Random sampling: the gradient w.r.t. self is defined as zeros.
- name: random_(Tensor self, int64_t from, int64_t to, Generator generator)
self: zeros_like(grad)
- name: random_(Tensor self, int64_t to, Generator generator)
self: zeros_like(grad)
- name: random_(Tensor self, Generator generator)
self: zeros_like(grad)
# d/dx 1/x = -1/x^2 = -result^2.
- name: reciprocal(Tensor self)
self: -grad * result * result
# remainder's derivative w.r.t. self is 1 almost everywhere.
- name: remainder(Tensor self, Scalar other)
self: grad
# remainder: the derivative w.r.t. self is 1 almost everywhere. The
# derivative w.r.t. a Tensor `other` is not implemented. Declaring it
# explicitly (as the parallel fmod overload does) means a request for it
# reaches not_implemented instead of `other` having no gradient entry at all.
# Quoted because the message contains ": ", which YAML would otherwise parse
# as a nested mapping.
- name: remainder(Tensor self, Tensor other)
self: grad
other: 'not_implemented("remainder: other")'
# ---------------------------------------------------------------------------
# Entries r-s: renorm, repeat, ROI pooling, rounding, rsqrt, scatter ops,
# sigmoid/sign/trig, sort, split and sqrt.
# ---------------------------------------------------------------------------
- name: renorm(Tensor self, Scalar p, int64_t dim, Scalar maxnorm)
self: renorm_backward(grad, self, p, dim, maxnorm)
# repeat_backward presumably sums grad over the tiled copies.
- name: repeat(Tensor self, IntList repeats)
self: repeat_backward(grad, self.dim(), repeats)
# The helper receives the inputs plus result1 (the second forward output).
- name: RoiPooling2d_forward(Tensor input, Tensor rois, int64_t pooledHeight, int64_t pooledWidth, double spatialScale)
input: RoiPooling2d_backward(input, rois, pooledHeight, pooledWidth, spatialScale, grad, result1)
# Piecewise constant: the derivative is zero almost everywhere.
- name: round(Tensor self)
self: zeros_like(grad)
# d/dx x^(-1/2) = -0.5 * x^(-3/2) = -0.5 * result^3.
- name: rsqrt(Tensor self)
self: -0.5 * grad * result.pow(3)
# Positions of self overwritten by scatter_ get zero gradient;
# `src` receives grad gathered from those positions.
- name: scatter_(Tensor self, int64_t dim, Tensor index, Tensor src)
self: grad.clone().scatter_(dim, index, 0)
src: grad.gather(dim, index)
- name: scatter_(Tensor self, int64_t dim, Tensor index, Scalar value)
self: grad.clone().scatter_(dim, index, 0)
# scatter_add_ keeps self's values, so self's gradient passes through unchanged.
- name: scatter_add_(Tensor self, int64_t dim, Tensor index, Tensor src)
self: grad
src: grad.gather(dim, index)
- name: sigmoid(Tensor self)
self: _sigmoid_backward(grad, result)
# Piecewise constant: the derivative is zero almost everywhere.
- name: sign(Tensor self)
self: zeros_like(grad)
- name: sin(Tensor self)
self: grad * self.cos()
- name: sinh(Tensor self)
self: grad * self.cosh()
# Passes `true` where other selections pass keepdim, presumably because the
# sorted output keeps `dim`.
- name: sort(Tensor self, int64_t dim, bool descending)
self: select_backward(grad, dim, indices, self.sizes(), true)
# Multiple outputs: the helpers receive every chunk's gradient (`grads`).
- name: split(Tensor self, int64_t split_size, int64_t dim)
self: split_backward(grads, split_size, dim, self.sizes(), self.type())
- name: split_with_sizes(Tensor self, IntList split_sizes, int64_t dim)
self: split_with_sizes_backward(grads, split_sizes, dim, self.sizes(), self.type())
# d/dx sqrt(x) = 1 / (2 * sqrt(x)) = 1 / (2 * result).
- name: sqrt(Tensor self)
self: grad / (2 * result)
# squeeze() drops size-1 dimensions; unsqueeze_to presumably re-inserts them
# so the gradient matches self.sizes(). Like every other formula this is a
# bare expression (no trailing ';').
- name: squeeze(Tensor self)
self: unsqueeze_to(grad, self.sizes())
- name: squeeze(Tensor self, int64_t dim)
self: unsqueeze_to(grad, dim, self.sizes())
# std = sqrt(var), so d std = d var / (2 * std); the variance gradient
# helper is reused with grad / (2 * result).
# NOTE(review): if `result` is 0 (e.g. constant input) this divides by zero.
- name: std(Tensor self, bool unbiased)
self: var_backward(grad / (result * 2), self, unbiased)
- name: std(Tensor self, int64_t dim, bool unbiased, bool keepdim)
self: var_backward(grad / (result * 2), self, dim, unbiased, keepdim)
# sub: the formulas correspond to out = self - alpha * other.
# A Scalar `other` gets no gradient entry; only `self` is differentiable.
- name: sub(Tensor self, Scalar other, *, Scalar alpha)
self: grad
# `other` receives -alpha * grad. It uses maybe_multiply like add's `other`
# formula, so both ops scale the gradient through the same helper.
- name: sub(Tensor self, Tensor other, *, Scalar alpha)
self: grad
other: maybe_multiply(-grad, alpha)
# ---------------------------------------------------------------------------
# Entries s-z (plus underscore-prefixed internals): sums, SVD, shape ops,
# triangular ops, variance, views, where, sparse mask and gamma sampling.
# ---------------------------------------------------------------------------
- name: sum(Tensor self)
self: grad.expand(self.sizes())
- name: sum(Tensor self, int64_t dim, bool keepdim)
self: sum_backward(grad, self.sizes(), dim, keepdim)
# Multiple outputs: the helper receives every output gradient (`grads`).
# res1..res3 are the forward outputs, named as in the op's declaration.
- name: svd(Tensor self, bool some)
self: svd_backward(grads, self, some, res1, res2, res3)
- name: symeig(Tensor self, bool eigenvectors, bool upper)
self: not_implemented("symeig")
- name: t(Tensor self)
self: grad.t()
# put_ with accumulate=true so grads for repeated indices are summed.
- name: take(Tensor self, Tensor index)
self: zeros_like(self).put_(index, grad, true)
# d/dx tan(x) = 1 + tan(x)^2 = 1 + result^2.
- name: tan(Tensor self)
self: grad * (1 + result.pow(2))
- name: tanh(Tensor self)
self: _tanh_backward(grad, result)
# Passes `true` where other selections pass keepdim, presumably because
# topk's output keeps `dim`.
- name: topk(Tensor self, int64_t k, int64_t dim, bool largest, bool sorted)
self: select_backward(grad, dim, indices, self.sizes(), true)
- name: trace(Tensor self)
self: trace_backward(grad, self.sizes())
# Swapping the same two dims again undoes the transpose.
- name: transpose(Tensor self, int64_t dim0, int64_t dim1)
self: grad.transpose(dim0, dim1)
# Only the kept triangle contributed to the output, so grad is masked the same way.
- name: tril(Tensor self, int64_t diagonal)
self: grad.tril(diagonal)
- name: triu(Tensor self, int64_t diagonal)
self: grad.triu(diagonal)
# Two differentiable outputs (grads[0], grads[1]); res1 is a forward output.
# grad_input_mask lets the helper skip whichever input gradient is not needed.
- name: trtrs(Tensor self, Tensor A, bool upper, bool transpose, bool unitriangular)
self, A: trtrs_backward(grads[0], grads[1], self, A, res1, upper, transpose, unitriangular, grad_input_mask)
# Piecewise constant: the derivative is zero almost everywhere.
- name: trunc(Tensor self)
self: zeros_like(grad)
- name: unfold(Tensor self, int64_t dimension, int64_t size, int64_t step)
self: unfold_backward(grad, self.sizes(), dimension, size, step)
# Random sampling: the gradient w.r.t. self is defined as zeros.
- name: uniform_(Tensor self, double from, double to, Generator generator)
self: zeros_like(grad)
- name: _unique(Tensor self, bool sorted, bool return_inverse)
self: not_implemented("_unique")
# grad is made contiguous before being viewed back to self's shape.
- name: _unsafe_view(Tensor self, IntList size)
self: grad.contiguous().view(self.sizes())
- name: unsqueeze(Tensor self, int64_t dim)
self: grad.squeeze(dim)
- name: var(Tensor self, bool unbiased)
self: var_backward(grad, self, unbiased)
- name: var(Tensor self, int64_t dim, bool unbiased, bool keepdim)
self: var_backward(grad, self, dim, unbiased, keepdim)
- name: view(Tensor self, IntList size)
self: grad.contiguous().view(self.sizes())
# grad goes to `self` where condition holds and to `other` elsewhere;
# `condition` has no gradient entry.
- name: _s_where(Tensor condition, Tensor self, Tensor other)
self: where(condition, grad, zeros_like(grad))
other: where(condition, zeros_like(grad), grad)
# zero_ overwrites self, so its gradient is zero.
- name: zero_(Tensor self)
self: zeros_like(grad)
- name: _sparse_mask(Tensor self, SparseTensor mask)
self: not_implemented("_sparse_mask")
mask: not_implemented("_sparse_mask")
# `output` is the forward sample; _standard_gamma_grad presumably gives
# d(sample)/d(self) for it.
- name: _standard_gamma(Tensor self, Generator generator)
self: grad * self._standard_gamma_grad(output)
- name: _standard_gamma_grad(Tensor self, Tensor output)
self: not_implemented("_standard_gamma_grad")
# ---------------------------------------------------------------------------
# NN: losses, embeddings and activation functions. Each *_forward op
# delegates to a matching *_backward function. Only the input (self/weight)
# has a gradient entry; `target` has none.
# ---------------------------------------------------------------------------
- name: binary_cross_entropy_forward(Tensor self, Tensor target, Tensor weight, bool size_average, bool reduce)
self: binary_cross_entropy_backward(grad, self, target, weight, size_average, reduce)
# weight.size(0) is passed to the helper (presumably the number of embeddings).
- name: embedding(Tensor weight, Tensor indices, int64_t padding_idx, bool scale_grad_by_freq, bool sparse)
weight: embedding_backward(grad, indices, weight.size(0), padding_idx, scale_grad_by_freq, sparse)
# result1 and result2 are additional forward outputs consumed by the helper.
- name: embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int64_t mode, bool sparse)
weight: embedding_bag_backward(grad, indices, offsets, result1, result2, weight.size(0), scale_grad_by_freq, mode, sparse)
- name: embedding_renorm_(Tensor self, Tensor indices, double max_norm, double norm_type)
self: not_implemented("embedding_renorm")
- name: kl_div_forward(Tensor self, Tensor target, bool size_average, bool reduce)
self: kl_div_backward(grad, self, target, size_average, reduce)
- name: l1_loss_forward(Tensor self, Tensor target, bool size_average, bool reduce)
self: l1_loss_backward(grad, self, target, size_average, reduce)
- name: mse_loss_forward(Tensor self, Tensor target, bool size_average, bool reduce)
self: mse_loss_backward(grad, self, target, size_average, reduce)
- name: multi_margin_loss_forward(Tensor self, Tensor target, Scalar p, Scalar margin, Tensor weight, bool size_average, bool reduce)
self: multi_margin_loss_backward(grad, self, target, p, margin, weight, size_average, reduce)
# is_target is a forward output (it is not in the signature).
- name: multilabel_margin_loss_forward(Tensor self, Tensor target, bool size_average, bool reduce)
self: multilabel_margin_loss_backward(grad, self, target, size_average, reduce, is_target)
# total_weight is a forward output (it is not in the signature).
- name: nll_loss_forward(Tensor self, Tensor target, Tensor weight, bool size_average, int64_t ignore_index, bool reduce)
self: nll_loss_backward(grad, self, target, weight, size_average, ignore_index, reduce, total_weight)
- name: nll_loss2d_forward(Tensor self, Tensor target, Tensor weight, bool size_average, int64_t ignore_index, bool reduce)
self: nll_loss2d_backward(grad, self, target, weight, size_average, ignore_index, reduce, total_weight)
- name: smooth_l1_loss_forward(Tensor self, Tensor target, bool size_average, bool reduce)
self: smooth_l1_loss_backward(grad, self, target, size_average, reduce)
- name: soft_margin_loss_forward(Tensor self, Tensor target, bool size_average, bool reduce)
self: soft_margin_loss_backward(grad, self, target, size_average, reduce)
# Activations. The in-place (trailing "_") variants pass `output` where the
# out-of-place ones pass `self`, presumably because self has been
# overwritten by the time backward runs.
- name: elu_forward(Tensor self, Scalar alpha, Scalar scale)
self: elu_backward(grad, alpha, scale, output)
- name: glu_forward(Tensor self, int64_t dim)
self: glu_backward(grad, self, dim)
- name: hardshrink_forward(Tensor self, Scalar lambd)
self: hardshrink_backward(grad, self, lambd)
- name: hardtanh_forward(Tensor self, Scalar min_val, Scalar max_val)
self: hardtanh_backward(grad, self, min_val, max_val)
- name: hardtanh_forward_(Tensor self, Scalar min_val, Scalar max_val)
self: hardtanh_backward(grad, output, min_val, max_val)
- name: leaky_relu_forward(Tensor self, Scalar negative_slope)
self: leaky_relu_backward(grad, self, negative_slope)
- name: leaky_relu_forward_(Tensor self, Scalar negative_slope)
self: leaky_relu_backward(grad, output, negative_slope)
# `buffer` is a forward output reused by the backward.
- name: log_sigmoid_forward(Tensor self)
self: log_sigmoid_backward(grad, self, buffer)
- name: log_softmax_forward(Tensor self, int64_t dim)
self: log_softmax_backward(grad, self, dim, output)
# One helper returns both gradients; grad_input_mask selects which are needed.
- name: prelu_forward(Tensor self, Tensor weight)
self, weight: prelu_backward(grad, self, weight, grad_input_mask)
- name: rrelu_with_noise_forward(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator generator)
self: rrelu_with_noise_backward(grad, self, noise, lower, upper, training)
- name: rrelu_with_noise_forward_(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator generator)
self: rrelu_with_noise_backward(grad, output, noise, lower, upper, training)
- name: softmax_forward(Tensor self, int64_t dim)
self: softmax_backward(grad, self, dim, output)
- name: softplus_forward(Tensor self, Scalar beta, Scalar threshold)
self: softplus_backward(grad, self, beta, threshold, output)
- name: softshrink_forward(Tensor self, Scalar lambd)
self: softshrink_backward(grad, self, lambd)
- name: threshold_forward(Tensor self, Scalar threshold, Scalar value)
self: threshold_backward(grad, self, threshold, value)
- name: threshold_forward_(Tensor self, Scalar threshold, Scalar value)
self: threshold_backward(grad, output, threshold, value)
# ---------------------------------------------------------------------------
# NN: padding, upsampling and pooling. Each forward delegates to its matching
# *_backward helper.
# ---------------------------------------------------------------------------
- name: reflection_pad1d_forward(Tensor self, IntList padding)
self: reflection_pad1d_backward(grad, self, padding)
- name: reflection_pad2d_forward(Tensor self, IntList padding)
self: reflection_pad2d_backward(grad, self, padding)
- name: replication_pad1d_forward(Tensor self, IntList padding)
self: replication_pad1d_backward(grad, self, padding)
- name: replication_pad2d_forward(Tensor self, IntList padding)
self: replication_pad2d_backward(grad, self, padding)
- name: replication_pad3d_forward(Tensor self, IntList padding)
self: replication_pad3d_backward(grad, self, padding)
# (Bi/tri)linear upsampling backward needs only output_size and the input
# sizes, not self's values.
- name: upsample_linear1d_forward(Tensor self, IntList output_size)
self: upsample_linear1d_backward(grad, output_size, self.sizes())
- name: upsample_bilinear2d_forward(Tensor self, IntList output_size)
self: upsample_bilinear2d_backward(grad, output_size, self.sizes())
- name: upsample_trilinear3d_forward(Tensor self, IntList output_size)
self: upsample_trilinear3d_backward(grad, output_size, self.sizes())
- name: upsample_nearest1d_forward(Tensor self, int64_t scale_factor)
self: upsample_nearest1d_backward(grad, self, scale_factor)
- name: upsample_nearest2d_forward(Tensor self, int64_t scale_factor)
self: upsample_nearest2d_backward(grad, self, scale_factor)
- name: upsample_nearest3d_forward(Tensor self, int64_t scale_factor)
self: upsample_nearest3d_backward(grad, self, scale_factor)
- name: adaptive_avg_pool2d_forward(Tensor self, IntList output_size)
self: adaptive_avg_pool2d_backward(grad, self)
- name: adaptive_avg_pool3d_forward(Tensor self, IntList output_size)
self: adaptive_avg_pool3d_backward(grad, self)
# Max-style pools route grad through `indices`, a forward output (it does
# not appear in these signatures).
- name: adaptive_max_pool2d_forward(Tensor self, IntList output_size)
self: adaptive_max_pool2d_backward(grad, self, indices)
- name: adaptive_max_pool3d_forward(Tensor self, IntList output_size)
self: adaptive_max_pool3d_backward(grad, self, indices)
- name: avg_pool2d_forward(Tensor self, IntList kernel_size, IntList stride, IntList padding, bool ceil_mode, bool count_include_pad)
self: avg_pool2d_backward(grad, self, kernel_size, stride, padding, ceil_mode, count_include_pad)
- name: avg_pool3d_forward(Tensor self, IntList kernel_size, IntList stride, IntList padding, bool ceil_mode, bool count_include_pad)
self: avg_pool3d_backward(grad, self, kernel_size, stride, padding, ceil_mode, count_include_pad)
# random_samples has no gradient entry; the backward uses the forward's indices.
- name: fractional_max_pool2d_forward(Tensor self, IntList kernel_size, IntList output_size, Tensor random_samples)
self: fractional_max_pool2d_backward(grad, self, kernel_size, output_size, indices)
- name: max_pool2d_forward(Tensor self, IntList kernel_size, IntList stride, IntList padding, IntList dilation, bool ceil_mode)
self: max_pool2d_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode, indices)
- name: max_pool3d_forward(Tensor self, IntList kernel_size, IntList stride, IntList padding, IntList dilation, bool ceil_mode)
self: max_pool3d_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode, indices)
# Unpooling: `indices` is an input here and has no gradient entry.
- name: max_unpool2d_forward(Tensor self, Tensor indices, IntList output_size)
self: max_unpool2d_backward(grad, self, indices, output_size)
- name: max_unpool3d_forward(Tensor self, Tensor indices, IntList output_size, IntList stride, IntList padding)
self: max_unpool3d_backward(grad, self, indices, output_size, stride, padding)
# ---------------------------------------------------------------------------
# NN: batch norm and convolutions. Each *_forward entry's gradient is the
# matching *_backward op, which returns gradients for (self, weight, bias).
# Each *_backward entry defines double backward: grads[0..2] are the
# gradients flowing into those three outputs.
# ---------------------------------------------------------------------------
# grad is made contiguous before use; save_mean/save_std are forward outputs.
- name: thnn_batch_norm_forward(Tensor self, Tensor weight, Tensor bias, Tensor running_mean, Tensor running_var, bool training, double momentum, double eps)
self, weight, bias: thnn_batch_norm_backward(grad.contiguous(), self, weight, running_mean, running_var, training, eps, save_mean, save_std, grad_input_mask)
- name: thnn_batch_norm_backward(Tensor grad_output, Tensor self, Tensor weight, Tensor running_mean, Tensor running_var, bool training, double eps, Tensor save_mean, Tensor save_std, std::array<bool,3> output_mask)
self, weight, grad_output: batchnorm_double_backward(self, weight, grads[0], grads[1], grads[2], grad_output, running_mean, running_var, training, eps, save_mean, save_std, grad_input_mask)
save_mean: not_implemented("thnn_batch_norm_backward save_mean")
save_std: not_implemented("thnn_batch_norm_backward save_std")
# In the _convolution_double_backward calls below, the bool after dilation
# is true for the transposed convolutions and false otherwise.
# (columns, ones) and (finput, fgrad_input) are forward buffer outputs.
- name: thnn_conv_transpose2d_forward(Tensor self, Tensor weight, IntList kernel_size, Tensor bias, IntList stride, IntList padding, IntList output_padding, IntList dilation)
self, weight, bias: thnn_conv_transpose2d_backward(grad, self, weight, kernel_size, stride, padding, output_padding, dilation, columns, ones, grad_input_mask)
- name: thnn_conv_transpose2d_backward(Tensor grad_output, Tensor self, Tensor weight, IntList kernel_size, IntList stride, IntList padding, IntList output_padding, IntList dilation, Tensor columns, Tensor ones, std::array<bool,3> output_mask)
grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, true, output_padding, 1, false, false, false, grad_input_mask)
- name: thnn_conv_transpose3d_forward(Tensor self, Tensor weight, IntList kernel_size, Tensor bias, IntList stride, IntList padding, IntList output_padding, IntList dilation)
self, weight, bias: thnn_conv_transpose3d_backward(grad, self, weight, kernel_size, stride, padding, output_padding, dilation, finput, fgrad_input, grad_input_mask)
- name: thnn_conv_transpose3d_backward(Tensor grad_output, Tensor self, Tensor weight, IntList kernel_size, IntList stride, IntList padding, IntList output_padding, IntList dilation, Tensor finput, Tensor fgrad_input, std::array<bool,3> output_mask)
grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, true, output_padding, 1, false, false, false, grad_input_mask)
# Non-dilated convolutions: double backward passes unit dilation ({{1, 1}})
# and zero output padding ({{0, 0}}).
- name: thnn_conv2d_forward(Tensor self, Tensor weight, IntList kernel_size, Tensor bias, IntList stride, IntList padding)
self, weight, bias: thnn_conv2d_backward(grad, self, weight, kernel_size, stride, padding, finput, fgrad_input, grad_input_mask)
- name: thnn_conv2d_backward(Tensor grad_output, Tensor self, Tensor weight, IntList kernel_size, IntList stride, IntList padding, Tensor finput, Tensor fgrad_input, std::array<bool,3> output_mask)
grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, {{1, 1}}, false, {{0, 0}}, 1, false, false, false, grad_input_mask)
# Depthwise: the backward op returns only (self, weight) grads (2-entry mask).
# The bias gradient is computed inline by summing grad over dim 0 and all
# trailing dims (assumes grad is (N, C, ...) - confirm).
- name: thnn_conv_depthwise2d_forward(Tensor self, Tensor weight, IntList kernel_size, Tensor bias, IntList stride, IntList padding, IntList dilation)
self, weight: thnn_conv_depthwise2d_backward(grad.contiguous(), self, weight, kernel_size, stride, padding, dilation, grad_input_mask)
bias: grad.contiguous().view({grad.size(0), grad.size(1), -1}).sum(0).sum(1)
# No bias output, so the third grad is {}. self.size(1) is passed where the
# other convolutions pass 1 (presumably groups = number of input channels).
- name: thnn_conv_depthwise2d_backward(Tensor grad_output, Tensor self, Tensor weight, IntList kernel_size, IntList stride, IntList padding, IntList dilation, std::array<bool,2> output_mask)
grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], {}, grad_output, weight, self, stride, padding, dilation, false, {{0, 0}}, self.size(1), false, false, false, grad_input_mask)
- name: thnn_conv3d_forward(Tensor self, Tensor weight, IntList kernel_size, Tensor bias, IntList stride, IntList padding)
self, weight, bias: thnn_conv3d_backward(grad, self, weight, kernel_size, stride, padding, finput, fgrad_input, grad_input_mask)
- name: thnn_conv3d_backward(Tensor grad_output, Tensor self, Tensor weight, IntList kernel_size, IntList stride, IntList padding, Tensor finput, Tensor fgrad_input, std::array<bool,3> output_mask)
grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, {{1, 1, 1}}, false, {{0, 0, 0}}, 1, false, false, false, grad_input_mask)
# Dilated THNN convolutions.
# The forward entries route all three gradients through the matching
# *_backward call, passing `columns`/`ones`. These are not in the forward
# signature (presumably Declarations.cwrap buffers — confirm).
# The double-backward entries pass the real `dilation` and a zero
# output_padding sized for 2d ({{0, 0}}) or 3d ({{0, 0, 0}}).
- name: thnn_conv_dilated2d_forward(Tensor self, Tensor weight, IntList kernel_size, Tensor bias, IntList stride, IntList padding, IntList dilation)
  self, weight, bias: thnn_conv_dilated2d_backward(grad, self, weight, kernel_size, stride, padding, dilation, columns, ones, grad_input_mask)
- name: thnn_conv_dilated2d_backward(Tensor grad_output, Tensor self, Tensor weight, IntList kernel_size, IntList stride, IntList padding, IntList dilation, Tensor columns, Tensor ones, std::array<bool,3> output_mask)
  grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, {{0, 0}}, 1, false, false, false, grad_input_mask)
- name: thnn_conv_dilated3d_forward(Tensor self, Tensor weight, IntList kernel_size, Tensor bias, IntList stride, IntList padding, IntList dilation)
  self, weight, bias: thnn_conv_dilated3d_backward(grad, self, weight, kernel_size, stride, padding, dilation, columns, ones, grad_input_mask)
- name: thnn_conv_dilated3d_backward(Tensor grad_output, Tensor self, Tensor weight, IntList kernel_size, IntList stride, IntList padding, IntList dilation, Tensor columns, Tensor ones, std::array<bool,3> output_mask)
  grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, {{0, 0, 0}}, 1, false, false, false, grad_input_mask)
# NN double backwards support
# Pooling double-backward. For every entry here, the gradient w.r.t. `self`
# is declared as zeros_like(self).
#
# Gradient w.r.t. grad_output:
#   - Average pools apply the corresponding forward pool to `grad`.
#     The adaptive variants use grad_output's trailing spatial sizes as the
#     output size.
#   - Adaptive max pools call max_pool_double_backward with the saved
#     `indices` and a trailing 2 or 3 (presumably the number of spatial dims
#     — confirm against the helper).
- name: adaptive_avg_pool2d_backward(Tensor grad_output, Tensor self)
  grad_output: adaptive_avg_pool2d(grad, { grad_output.size(-2), grad_output.size(-1) })
  self: zeros_like(self)
- name: adaptive_avg_pool3d_backward(Tensor grad_output, Tensor self)
  grad_output: adaptive_avg_pool3d(grad, { grad_output.size(-3), grad_output.size(-2), grad_output.size(-1) })
  self: zeros_like(self)
- name: adaptive_max_pool2d_backward(Tensor grad_output, Tensor self, Tensor indices)
  grad_output: max_pool_double_backward(grad, indices, 2)
  self: zeros_like(self)
- name: adaptive_max_pool3d_backward(Tensor grad_output, Tensor self, Tensor indices)
  grad_output: max_pool_double_backward(grad, indices, 3)
  self: zeros_like(self)
- name: avg_pool2d_backward(Tensor grad_output, Tensor self, IntList kernel_size, IntList stride, IntList padding, bool ceil_mode, bool count_include_pad)
  grad_output: avg_pool2d(grad, kernel_size, stride, padding, ceil_mode, count_include_pad)
  self: zeros_like(self)
- name: avg_pool3d_backward(Tensor grad_output, Tensor self, IntList kernel_size, IntList stride, IntList padding, bool ceil_mode, bool count_include_pad)
  grad_output: avg_pool3d(grad, kernel_size, stride, padding, ceil_mode, count_include_pad)
  self: zeros_like(self)
# elu_backward:
#   - The grad_output term reuses elu_backward on `grad`.
#   - The `output` term is grad * grad_output, masked to elements where
#     output < 0 (the mask is cast to grad's type).
#   - NOTE(review): the mask is strict (< 0), so output == 0 contributes
#     nothing. Confirm this matches the branch condition elu_backward uses.
# fractional_max_pool2d_backward: the grad_output term uses
#   max_pool_double_backward with the saved `indices`; self gets zeros.
# glu_backward: both terms are delegated to glu-specific helpers.
- name: elu_backward(Tensor grad_output, Scalar alpha, Scalar scale, Tensor output)
  grad_output: elu_backward(grad, alpha, scale, output)
  output: grad * grad_output * (output < 0).toType(grad.type())
- name: fractional_max_pool2d_backward(Tensor grad_output, Tensor self, IntList kernel_size, IntList output_size, Tensor indices)
  grad_output: max_pool_double_backward(grad, indices, 2)
  self: zeros_like(self)
- name: glu_backward(Tensor grad_output, Tensor self, int64_t dim)
  grad_output: glu_double_backward_grad_output(grad, self, dim)
  self: glu_double_backward(grad, grad_output, self, dim)
# Double-backward for these entries.
#
# Gradient w.r.t. grad_output: reuses the op's own backward, or a
# *_double_backward_grad_output helper, applied to `grad`.
#
# Gradient w.r.t. `self`:
#   - Declared as zeros for hardshrink, hardtanh, kl_div, l1_loss and
#     leaky_relu.
#   - log_sigmoid and log_softmax have non-zero `self` terms computed by
#     dedicated helpers.
#
# log_softmax_backward writes its grad_output term inline:
#   grad - (grad * output.exp()).sum(dim, true)
# Assuming log_softmax_backward(g) = g - exp(output) * g.sum(dim, true)
# (confirm), this is the adjoint of that map in g.
- name: hardshrink_backward(Tensor grad_output, Tensor self, Scalar lambd)
  grad_output: hardshrink_backward(grad, self, lambd)
  self: zeros_like(grad)
- name: hardtanh_backward(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val)
  grad_output: hardtanh_backward(grad, self, min_val, max_val)
  self: zeros_like(grad)
- name: kl_div_backward(Tensor grad_output, Tensor self, Tensor target, bool size_average, bool reduce)
  grad_output: kl_div_double_backward_grad_output(grad, self, target, size_average, reduce)
  self: zeros_like(grad)
- name: l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, bool size_average, bool reduce)
  grad_output: l1_loss_double_backward_grad_output(grad, self, target, size_average, reduce)
  self: zeros_like(grad)
- name: log_sigmoid_backward(Tensor grad_output, Tensor self, Tensor buffer)
  grad_output: log_sigmoid_backward(grad, self, buffer)
  self: log_sigmoid_double_backward(grad * grad_output, self)
- name: log_softmax_backward(Tensor grad_output, Tensor self, int64_t dim, Tensor output)
  grad_output: grad - (grad * output.exp()).sum(dim, true)
  self: log_softmax_double_backward(grad, grad_output, dim, output)
- name: leaky_relu_backward(Tensor grad_output, Tensor self, Scalar negative_slope)
  grad_output: leaky_relu_backward(grad, self, negative_slope)
  self: zeros_like(grad)
# max_pool{2,3}d_backward:
#   - The gradient w.r.t. grad_output comes from max_pool_double_backward,
#     applied to `grad` and the saved `indices`. The trailing 2 or 3 is
#     presumably the number of spatial dims (confirm against the helper).
#   - The gradient w.r.t. `self` is declared as zeros.
# NB: formulas are C++ expressions that the generator splices into
# statements such as `auto grad_result = <formula>;`. They therefore must
# not end in `;`, which would yield `;;` today and a compile error inside
# any larger expression.
- name: max_pool2d_backward(Tensor grad_output, Tensor self, IntList kernel_size, IntList stride, IntList padding, IntList dilation, bool ceil_mode, Tensor indices)
  grad_output: max_pool_double_backward(grad, indices, 2)
  self: zeros_like(self)
- name: max_pool3d_backward(Tensor grad_output, Tensor self, IntList kernel_size, IntList stride, IntList padding, IntList dilation, bool ceil_mode, Tensor indices)
  grad_output: max_pool_double_backward(grad, indices, 3)
  self: zeros_like(self)
# max_unpool2d_backward: the grad_output term re-applies max_unpool2d to
#   `grad` with the same indices and output_size; self gets zeros.
# mse_loss_backward: both terms come from mse-specific helpers. The self
#   term is called with grad * grad_output.
# nll_loss / nll_loss2d backward:
#   - The grad_output term re-runs the forward loss on `grad`, forwarding
#     target, weight, size_average, ignore_index and reduce unchanged.
#   - total_weight is not referenced.
#   - self gets zeros.
# prelu_backward: it has two outputs (std::array<bool,2> output_mask). So
#   grads[0] and grads[1] go into one helper, which returns the gradients for
#   grad_output, self and weight in that key order.
- name: max_unpool2d_backward(Tensor grad_output, Tensor self, Tensor indices, IntList output_size)
  grad_output: max_unpool2d(grad, indices, output_size)
  self: zeros_like(self)
- name: mse_loss_backward(Tensor grad_output, Tensor self, Tensor target, bool size_average, bool reduce)
  grad_output: mse_loss_double_backward_grad_output(grad, grad_output, self, target, size_average, reduce)
  self: mse_loss_double_backward(grad * grad_output, self, size_average, reduce)
- name: nll_loss_backward(Tensor grad_output, Tensor self, Tensor target, Tensor weight, bool size_average, int64_t ignore_index, bool reduce, Tensor total_weight)
  grad_output: nll_loss(grad, target, weight, size_average, ignore_index, reduce)
  self: zeros_like(grad)
- name: nll_loss2d_backward(Tensor grad_output, Tensor self, Tensor target, Tensor weight, bool size_average, int64_t ignore_index, bool reduce, Tensor total_weight)
  grad_output: nll_loss2d(grad, target, weight, size_average, ignore_index, reduce)
  self: zeros_like(grad)
- name: prelu_backward(Tensor grad_output, Tensor self, Tensor weight, std::array<bool,2> output_mask)
  grad_output, self, weight: prelu_double_backward(grads[0], grads[1], grad_output, self, weight, grad_input_mask)
# rrelu_with_noise_backward: the grad_output term reuses the backward on
#   `grad` with the same noise, bounds and training flag; self gets zeros.
# Reflection/replication padding backwards: the grad_output term re-applies
#   the forward pad to `grad` with the same padding; self gets zeros.
- name: rrelu_with_noise_backward(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training)
  grad_output: rrelu_with_noise_backward(grad, self, noise, lower, upper, training)
  self: zeros_like(grad)
- name: reflection_pad1d_backward(Tensor grad_output, Tensor self, IntList padding)
  grad_output: reflection_pad1d(grad, padding)
  self: zeros_like(self)
- name: reflection_pad2d_backward(Tensor grad_output, Tensor self, IntList padding)
  grad_output: reflection_pad2d(grad, padding)
  self: zeros_like(self)
- name: replication_pad1d_backward(Tensor grad_output, Tensor self, IntList padding)
  grad_output: replication_pad1d(grad, padding)
  self: zeros_like(self)
- name: replication_pad2d_backward(Tensor grad_output, Tensor self, IntList padding)
  grad_output: replication_pad2d(grad, padding)
  self: zeros_like(self)
- name: replication_pad3d_backward(Tensor grad_output, Tensor self, IntList padding)
  grad_output: replication_pad3d(grad, padding)
  self: zeros_like(self)
# smooth_l1_loss, softplus and soft_margin_loss have non-zero `self` terms.
#   Each comes from a dedicated *_double_backward helper called with
#   grad * grad_output.
# softmax_backward: the grad_output term reuses softmax_backward on `grad`;
#   the self term comes from softmax_double_backward.
# softshrink and threshold: the grad_output term reuses their backward on
#   `grad`; the self gradient is declared as zeros.
- name: smooth_l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, bool size_average, bool reduce)
  grad_output: smooth_l1_loss_double_backward_grad_output(grad, grad_output, self, target, size_average, reduce)
  self: smooth_l1_loss_double_backward(grad * grad_output, self, target, size_average, reduce)
- name: softplus_backward(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold, Tensor output)
  grad_output: softplus_backward(grad, self, beta, threshold, output)
  self: softplus_double_backward(grad * grad_output, self, beta, threshold)
- name: softmax_backward(Tensor grad_output, Tensor self, int64_t dim, Tensor output)
  grad_output: softmax_backward(grad, self, dim, output)
  self: softmax_double_backward(grad, grad_output, dim, output)
- name: soft_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, bool size_average, bool reduce)
  grad_output: soft_margin_loss_double_backward_grad_output(grad, grad_output, self, target, size_average, reduce)
  self: soft_margin_loss_double_backward(grad * grad_output, self, target, size_average, reduce)
- name: softshrink_backward(Tensor grad_output, Tensor self, Scalar lambd)
  grad_output: softshrink_backward(grad, self, lambd)
  self: zeros_like(grad)
- name: threshold_backward(Tensor grad_output, Tensor self, Scalar threshold, Scalar value)
  grad_output: threshold_backward(grad, self, threshold, value)
  self: zeros_like(grad)
# upsample_*_backward: the grad_output term re-applies the forward upsample
#   to `grad`.
#   - The linear/bilinear/trilinear backwards take no `self` (only sizes),
#     so they have only a grad_output entry.
#   - The nearest variants take `self` and declare a zero gradient for it.
# _sigmoid_backward / _tanh_backward: the `output` formulas assume
#   _sigmoid_backward(g, o) = g * o * (1 - o) and
#   _tanh_backward(g, o) = g * (1 - o^2) (confirm).
#   Differentiating in o gives grad_output * (1 - 2 * o) and
#   -2 * o * grad_output respectively, each multiplied by `grad`.
- name: upsample_linear1d_backward(Tensor grad_output, IntList output_size, IntList input_size)
  grad_output: upsample_linear1d(grad, output_size)
- name: upsample_bilinear2d_backward(Tensor grad_output, IntList output_size, IntList input_size)
  grad_output: upsample_bilinear2d(grad, output_size)
- name: upsample_trilinear3d_backward(Tensor grad_output, IntList output_size, IntList input_size)
  grad_output: upsample_trilinear3d(grad, output_size)
- name: upsample_nearest1d_backward(Tensor grad_output, Tensor self, int64_t scale_factor)
  grad_output: upsample_nearest1d(grad, scale_factor)
  self: zeros_like(grad)
- name: upsample_nearest2d_backward(Tensor grad_output, Tensor self, int64_t scale_factor)
  grad_output: upsample_nearest2d(grad, scale_factor)
  self: zeros_like(grad)
- name: upsample_nearest3d_backward(Tensor grad_output, Tensor self, int64_t scale_factor)
  grad_output: upsample_nearest3d(grad, scale_factor)
  self: zeros_like(grad)
- name: _sigmoid_backward(Tensor grad_output, Tensor output)
  grad_output: _sigmoid_backward(grad, output)
  output: grad * grad_output * (-2 * output + 1)
- name: _tanh_backward(Tensor grad_output, Tensor output)
  grad_output: _tanh_backward(grad, output)
  output: -2 * output * grad * grad_output
# cudnn
# cuDNN convolutions.
# The cudnn_*_backward calls take `self` BEFORE `grad` (unlike the thnn_*
# entries), matching the cudnn_*_backward signatures declared below.
#
# Double-backward goes through _convolution_double_backward with:
#   - the real groups/benchmark/deterministic values;
#   - a trailing `true`, presumably cudnn_enabled (confirm against the
#     helper's signature).
# The non-transposed backward has no output_padding argument, so a zero
# vector with one entry per padding dimension is built in its place.
- name: cudnn_convolution_transpose(Tensor self, Tensor weight, Tensor bias, IntList padding, IntList output_padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic)
  self, weight, bias: cudnn_convolution_transpose_backward(self, grad, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic, grad_input_mask)
- name: cudnn_convolution_transpose_backward(Tensor self, Tensor grad_output, Tensor weight, IntList padding, IntList output_padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic, std::array<bool,3> output_mask)
  grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, true, output_padding, groups, benchmark, deterministic, true, grad_input_mask)
- name: cudnn_convolution(Tensor self, Tensor weight, Tensor bias, IntList padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic)
  self, weight, bias: cudnn_convolution_backward(self, grad, weight, padding, stride, dilation, groups, benchmark, deterministic, grad_input_mask)
- name: cudnn_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, IntList padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic, std::array<bool,3> output_mask)
  grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, std::vector<int64_t>(padding.size(), 0), groups, benchmark, deterministic, true, grad_input_mask)
# The above backward definitions are equivalent to the definitions below. Why do we bundle
# everything up? It's because it's more convenient to define double backwards
# when there is a single function that manages everything.
#
# Unfortunately, there's one downside to bundling everything into a single
# function: we unconditionally save input and weight, even if weight/input
# gradients are not being computed. That's too bad.
#
# input: cudnn_convolution_backward_input(input.sizes(), grad.contiguous(), weight, padding, stride, dilation, groups, benchmark, deterministic)
# weight: cudnn_convolution_backward_weight(weight.sizes(), grad.contiguous(), input, padding, stride, dilation, groups, benchmark, deterministic)
# bias: cudnn_convolution_backward_bias(grad.contiguous())
#
# input: cudnn_convolution_transpose_backward_input(grad.contiguous(), weight, padding, stride, dilation, groups, benchmark, deterministic)
# weight: cudnn_convolution_transpose_backward_weight(weight.sizes(), grad.contiguous(), input, padding, stride, dilation, groups, benchmark, deterministic)
# bias: cudnn_convolution_backward_bias(grad.contiguous())
# cudnn_grid_sampler: gradients for both self and grid come from a single
#   cudnn_grid_sampler_backward call.
# cudnn_affine_grid_generator: only theta has a gradient formula.
#   N, C, H and W are forwarded to the backward unchanged.
- name: cudnn_grid_sampler(Tensor self, Tensor grid)
  self, grid: cudnn_grid_sampler_backward(self, grid, grad)
- name: cudnn_affine_grid_generator(Tensor theta, int64_t N, int64_t C, int64_t H, int64_t W)
  theta: cudnn_affine_grid_generator_backward(grad, N, C, H, W)
# NB: Why is the backwards here so complicated? CuDNN cannot be used to compute
# backward in evaluation mode, because the math for backward in evaluation mode
# is different (since the forward math is different), and CuDNN does not support
# it. And in any case, you shouldn't be using this bn in evaluation mode,
# because it should be merged into the previous convolution (left for future
# work.)
# NB2: The quotes around the gradient are needed to appease YAML parsing rules.
# cudnn_batch_norm backward:
#   - When `training` is true, it uses cudnn_batch_norm_backward.
#   - Otherwise it falls back to thnn_batch_norm_backward (so `training` is
#     false on that branch).
#   - result1/result2 are the forward's 2nd/3rd outputs (the header's
#     zero-based `resultX` convention). For cudnn_batch_norm_backward they
#     land in the save_mean/save_var slots.
- name: cudnn_batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor running_mean, Tensor running_var, bool training, double exponential_average_factor, double epsilon)
  input, weight, bias: "training ? cudnn_batch_norm_backward(input, grad.contiguous(), weight, running_mean, running_var, result1, result2, epsilon) : thnn_batch_norm_backward(grad.contiguous(), input, weight, running_mean, running_var, training, epsilon, result1, result2, grad_input_mask)"
# HACK: save_mean and save_var are going to be passed in as
# requires_grad variables (even though we'll never backprop through
# them) so we need to prevent the unpacking from triggering an error.
# The double-backward passes a literal `true` for the training flag, since
# cudnn_batch_norm_backward is only reached when training.
- name: cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor running_mean, Tensor running_var, Tensor save_mean, Tensor save_var, double epsilon)
  save_mean: not_implemented("cudnn_batch_norm_backward save_mean")
  save_var: not_implemented("cudnn_batch_norm_backward save_var")
  input, weight, grad_output: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_output, running_mean, running_var, true, epsilon, save_mean, save_var, grad_input_mask)
# nnpack
# nnpack_spatial_convolution:
#   - The input and weight gradients come from dedicated NNPACK backward
#     functions; weight.sizes() is passed so the weight gradient can be
#     shaped like weight.
#   - The bias gradient views `grad` as (grad.size(0), grad.size(1), -1) and
#     sums away every dim except dim 1. contiguous() is called first so
#     view() is valid.
- name: nnpack_spatial_convolution(Tensor input, Tensor weight, Tensor bias, int64_t kW, int64_t kH, int64_t padW, int64_t padH)
  input: nnpack_spatial_convolution_backward_input(input, grad, weight, kW, kH, padW, padH)
  weight: nnpack_spatial_convolution_backward_weight(input, weight.sizes(), grad, kW, kH, padW, padH)
  bias: grad.contiguous().view({grad.size(0), grad.size(1), -1}).sum(0).sum(1)
# _cudnn_rnn:
#   - A single _cudnn_rnn_backward call yields gradients for input, hx, cx
#     and the weight list, in that key order.
#   - grads[0..2] are the incoming gradients of the forward's first three
#     outputs.
#   - result0, result3 and result4 are forward outputs (zero-based
#     `resultX`).
#   - result3 is cloned when retain_variables is set. NOTE(review):
#     presumably the backward may overwrite it in place, so a retained graph
#     needs a fresh copy — confirm.
- name: _cudnn_rnn(Tensor input, TensorList weight, int64_t weight_stride0, Tensor weight_buf, Tensor hx, Tensor cx, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, IntList batch_sizes, Tensor dropout_state)
  input, hx, cx, weight: "_cudnn_rnn_backward(input, weight, weight_stride0, result4, hx, cx, result0, grads[0], grads[1], grads[2], mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, retain_variables ? result3.clone() : result3, grad_input_mask)"