| # Defines derivative formulas and Python signatures of methods on Variable |
| # |
| # Each entry consists of: |
| # - A 'name', which specifies the ATen name of the function you |
| # are defining derivatives for, and an argument specification. |
| # - One or more gradient entries, each mapping a differentiable input
| #   name to a formula specifying how to compute its gradient.
| # Note that a single gradient entry can specify the gradient |
| # formula for multiple input names, by specifying a key |
| # "input1, input2" (see atan2 for an example). |
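| #
| #   For example, the atan2 entry below reads:
| #
| #     - name: atan2(Tensor self, Tensor other)
| #       self, other: atan2_backward(grad, self, other, grad_input_mask)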
| # |
| # If a function has out-of-place and in-place variants, then the derivative |
| # definition for the in-place variant is optional. It will default to the |
| # definition for the out-of-place variant. Similarly, _out variants will |
| # default to the derivative for the non-_out variant.
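| #
| # For example, there is no entry for add_ below: the in-place variant simply
| # reuses the formulas defined for add.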
| # |
| # Gradient expressions are standard C++ expressions operating on ATen |
| # variables. In a gradient expression, the following variables are in |
| # scope: |
| # |
| # - 'grad', the gradient of the output (often spelled grad_output |
| # in Python) which we are going to left-multiply. |
| # |
| # When a function returns multiple *differentiable* outputs, |
| #   you can refer to the gradient of each output using 'grads',
| # e.g., 'grads[0]', 'grads[1]' |
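| #
| #   For example, svd below returns three differentiable outputs and its
| #   formula consumes all of their gradients at once:
| #
| #     - name: svd(Tensor self, bool some)
| #       self: svd_backward(grads, self, some, res1, res2, res3)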
| # |
| # When a function returns *one* differentiable output (the |
| # first output) and some more nondifferentiable outputs, |
| # you MUST refer to the gradient of the differentiable output with |
| # 'grad' (this case is special-cased in our code generation). |
| # |
| # - Any of the input arguments, tensor or non-tensor, including |
| # argument names that only appear in Declarations.cwrap, e.g. 'output'. |
| # |
| # - 'result', representing the result of evaluating the forward |
| # expression for ATen native function declarations. If the forward |
| # expression outputs a tuple, use 'resultX' instead to access the |
| #   X-th entry.
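| #
| #   For example, exp below reuses its forward result instead of
| #   recomputing it:
| #
| #     - name: exp(Tensor self)
| #       self: grad * result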
| # |
| # - 'grad_input_mask', a std::array<bool, n> specifying which input
| # gradients are actually needed. For example, in the entry |
| # `input0, input1: foo(grad_input_mask)`, `grad_input_mask` is a size |
| # two array, where `grad_input_mask[0]` is true if `input0` requires |
| # grad, and `grad_input_mask[1]` is true if `input1` requires grad. |
| # |
| #   (NB: if your function computes gradients for a list of tensors,
| #   `grad_input_mask` will only have a single entry for the whole list,
| #   specifying whether at least one tensor in the list requires
| #   grad. If we want to support more fine-grained signalling,
| #   we'll need some alternate variable which is not a std::array.)
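| #
| #   For example, prelu_forward below passes the mask through so that the
| #   backward can skip gradients that are not needed:
| #
| #     - name: prelu_forward(Tensor self, Tensor weight)
| #       self, weight: prelu_backward(grad, self, weight, grad_input_mask)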
| # |
| # - 'retain_variables', a bool which is true if a user has specified |
| # that saved variables should be retained in case the backwards is |
| # run again later. This allows an optimization where we can |
| # destroy saved buffers if we know variables are not going to be retained, |
| # e.g., it is used by _cudnn_rnn |
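| #
| #   For example, the _cudnn_rnn entry at the end of this file only clones one
| #   of its saved outputs when the buffers may be needed again:
| #
| #     retain_variables ? result3.clone() : result3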
| # |
| # If you need a complex expression, e.g., with local variables, |
| # write a _backward function in tools/autograd/templates/Functions.cpp |
| # and invoke it from here. By the way, go read |
| # https://github.com/zdevito/ATen/issues/163; this describes an |
| # important hazard that occurs when porting backwards from Python to C++.
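| #
| # A minimal sketch of such a helper (the names here are hypothetical) in
| # Functions.cpp:
| #
| #   Tensor my_op_backward(const Tensor & grad, const Tensor & self) {
| #     // Local variables are fine here, unlike in a one-line YAML formula.
| #     auto tmp = self * 2;
| #     return grad * tmp;
| #   }
| #
| # which you would then invoke from this file as
| # 'self: my_op_backward(grad, self)'.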
| # |
| # Double backwards gradient expressions can be somewhat confusing; |
| # the most important thing to remember is: (1) you need to define a |
| # derivative formula for every input, including inputs named things |
| # like 'grad_output', and (2) the gradient to multiply with is always |
| # called 'grad' (even though it really is a grad-grad). |
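| #
| # For example, _tanh_backward near the end of this file defines a formula for
| # each of its inputs, and 'grad' there is really a grad-grad:
| #
| #   - name: _tanh_backward(Tensor grad_output, Tensor output)
| #     grad_output: _tanh_backward(grad, output)
| #     output: -2 * output * grad * grad_output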
| # |
| # NB: There are a number of gradient definitions in here which are bogus |
| # (implemented using zeros_like). These gradients are (hopefully) not |
| # used by our frontend. You MUST check the frontend code; search for |
| # OpName.apply to see if it's still using a legacy Python-style API.
| # |
| # NB: The parameter names here MUST be consistent with the parameter names |
| # in ./torch/lib/ATen/Declarations.cwrap
|
| - name: abs(Tensor self) |
| self: grad * self.sign() |
| |
| - name: acos(Tensor self) |
| self: grad * -((-self * self + 1).rsqrt()) |
| |
| - name: add(Tensor self, Scalar other, *, Scalar alpha) |
| self: grad |
| |
| - name: add(Tensor self, Tensor other, *, Scalar alpha) |
| self: grad |
| other: maybe_multiply(grad, alpha) |
| |
| - name: addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta, Scalar alpha) |
| self: maybe_multiply(grad, beta) |
| batch1: grad.unsqueeze(0).expand({ batch1.size(0), batch1.size(1), batch2.size(2) }).bmm(batch2.transpose(1, 2)) * alpha |
| batch2: batch1.transpose(1, 2).bmm(grad.unsqueeze(0).expand({ batch1.size(0), batch1.size(1), batch2.size(2) })) * alpha |
| |
| - name: addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value) |
| self: grad |
| tensor1: grad * value / tensor2 |
| tensor2: -grad * value * tensor1 / (tensor2 * tensor2) |
| |
| - name: addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value) |
| self: grad |
| tensor1: grad * tensor2 * value |
| tensor2: grad * tensor1 * value |
| |
| - name: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) |
| self: maybe_multiply(grad, beta) |
| mat1: mm_mat1_backward(grad, mat2, mat1.sizes(), mat1.strides(), alpha) |
| mat2: mm_mat2_backward(grad, mat1, mat2.sizes(), mat2.strides(), alpha) |
| |
| - name: _addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta, Scalar alpha) |
| self: maybe_multiply(grad, beta) |
| mat: grad.ger(vec) * alpha |
| vec: mat.t().mv(grad) * alpha |
| |
| - name: _addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta, Scalar alpha) |
| self: maybe_multiply(grad, beta) |
| vec1: grad.mv(vec2) * alpha |
| vec2: grad.t().mv(vec1) * alpha |
| |
| - name: alias(Tensor self) |
| self: grad |
| |
| - name: as_strided(Tensor self, IntList size, IntList stride, int64_t storage_offset) |
| self: as_strided_backward(grad, TensorGeometry(self), size, stride, storage_offset) |
| |
| - name: asin(Tensor self) |
| self: grad * (-self * self + 1).rsqrt() |
| |
| - name: atan(Tensor self) |
| self: grad / (self * self + 1) |
| |
| - name: atan2(Tensor self, Tensor other) |
| self, other: atan2_backward(grad, self, other, grad_input_mask) |
| |
| - name: baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta, Scalar alpha) |
| self: maybe_multiply(grad, beta) |
| batch1: grad.bmm(batch2.transpose(1, 2)) * alpha |
| batch2: batch1.transpose(1, 2).bmm(grad) * alpha |
| |
| - name: bernoulli(Tensor self, Generator generator) |
| self: zeros_like(grad) |
| |
| - name: bmm(Tensor self, Tensor mat2) |
| self: grad.bmm(mat2.transpose(1, 2)) |
| mat2: self.transpose(1, 2).bmm(grad) |
| |
| - name: btrifact(Tensor self, bool pivot) |
| self: not_implemented("btrifact") |
| |
| - name: btrifact_with_info(Tensor self, bool pivot) |
| self: not_implemented("btrifact_with_info") |
| |
| - name: btrisolve(Tensor self, Tensor LU_data, Tensor LU_pivots) |
| self: not_implemented("btrisolve") |
| |
| - name: cat(TensorList tensors, int64_t dim) |
| tensors: cat_tensors_backward(grad, to_arg_sizes(tensors, dim), dim) |
| |
| - name: cauchy_(Tensor self, double median, double sigma, Generator generator) |
| self: zeros_like(grad) |
| |
| - name: ceil(Tensor self) |
| self: zeros_like(grad) |
| |
| - name: clamp(Tensor self, Scalar min, Scalar max) |
| self: grad * (self > min).type_as(grad) * (self < max).type_as(grad) |
| |
| - name: clamp_min(Tensor self, Scalar min) |
| self: grad * (self > min).type_as(grad) |
| |
| - name: clamp_max(Tensor self, Scalar max) |
| self: grad * (self < max).type_as(grad) |
| |
| - name: clone(Tensor self) |
| self: grad |
| |
| - name: cos(Tensor self) |
| self: grad * -self.sin() |
| |
| - name: cosh(Tensor self) |
| self: grad * self.sinh() |
| |
| - name: cross(Tensor self, Tensor other, int64_t dim) |
| self: other.cross(grad, dim) |
| other: grad.cross(self, dim) |
| |
| - name: cumprod(Tensor self, int64_t dim) |
| self: cumprod_backward(grad, self, dim) |
| |
| - name: cumsum(Tensor self, int64_t dim) |
| self: cumsum_backward(grad, dim) |
| |
| - name: conv_tbc(Tensor self, Tensor weight, Tensor bias, int64_t pad) |
| self, weight, bias: conv_tbc_backward(grad, self, weight, bias, pad) |
| |
| - name: _det_with_svd(Tensor self) |
| self: _det_with_svd_backward(grads, self, result0, result1, result2, result3) |
| |
| - name: diag(Tensor self, int64_t diagonal) |
| self: diag_backward(grad, self.sizes(), diagonal) |
| |
| - name: dist(Tensor self, Tensor other, Scalar p) |
| self: norm_backward(grad, self - other, p, result) |
| other: -norm_backward(grad, self - other, p, result) |
| |
| - name: div(Tensor self, Scalar other) |
| self: grad / other |
| |
| - name: div(Tensor self, Tensor other) |
| self: grad / other |
| other: -grad * self / (other * other) |
| |
| - name: dot(Tensor self, Tensor tensor) |
| self: grad * tensor |
| tensor: grad * self |
| |
| - name: eig(Tensor self, bool eigenvectors) |
| self: not_implemented("eig") |
| |
| - name: eq_(Tensor self, Scalar other) |
| self: zeros_like(self) |
| |
| - name: eq_(Tensor self, Tensor other) |
| self: zeros_like(self) |
| other: zeros_like(other) |
| |
| - name: erf(Tensor self) |
| self: 2.0 / sqrt(M_PI) * exp(-(self.pow(2))) * grad |
| |
| - name: erfinv(Tensor self) |
| self: 0.5 * sqrt(M_PI) * exp(self.erfinv().pow(2)) * grad |
| |
| - name: exp(Tensor self) |
| self: grad * result |
| |
| - name: expm1(Tensor self) |
| self: grad * (result + 1) |
| |
| - name: expand(Tensor self, IntList size) |
| self: reduce_to(grad, self.sizes()) |
| |
| - name: exponential_(Tensor self, double lambd, Generator generator) |
| self: zeros_like(grad) |
| |
| - name: fill_(Tensor self, Scalar value) |
| self: zeros_like(grad) |
| |
| - name: fill_(Tensor self, Tensor value) |
| self: zeros_like(grad) |
| value: grad.sum() |
| |
| - name: floor(Tensor self) |
| self: zeros_like(grad) |
| |
| - name: fmod(Tensor self, Scalar other) |
| self: grad |
| |
| - name: fmod(Tensor self, Tensor other) |
| self: grad |
| other: 'not_implemented("fmod: other")' |
| |
| - name: frac(Tensor self) |
| self: grad |
| |
| - name: gather(Tensor self, int64_t dim, Tensor index) |
| self: at::zeros(grad.type(), self.sizes()).scatter_add_(dim, index, grad) |
| |
| - name: ge_(Tensor self, Scalar other) |
| self: zeros_like(self) |
| |
| - name: ge_(Tensor self, Tensor other) |
| self: zeros_like(self) |
| other: zeros_like(other) |
| |
| - name: gels(Tensor self, Tensor A) |
| self: not_implemented("gels") |
| A: not_implemented("gels") |
| |
| - name: geometric_(Tensor self, double p, Generator generator) |
| self: zeros_like(grad) |
| |
| - name: geqrf(Tensor self) |
| self: not_implemented("geqrf") |
| |
| - name: ger(Tensor self, Tensor vec2) |
| self: grad.mv(vec2) |
| vec2: grad.t().mv(self) |
| |
| - name: gesv(Tensor self, Tensor A) |
| self: std::get<0>(gesv(grad, A.t())) |
| A: -at::mm(std::get<0>(gesv(grad, A.t())), solution.t()) |
| |
| - name: gt_(Tensor self, Scalar other) |
| self: zeros_like(self) |
| |
| - name: gt_(Tensor self, Tensor other) |
| self: zeros_like(self) |
| other: zeros_like(other) |
| |
| - name: histc(Tensor self, int64_t bins, Scalar min, Scalar max) |
| self: not_implemented("histc") |
| |
| - name: index_add_(Tensor self, int64_t dim, Tensor index, Tensor source) |
| self: grad |
| source: grad.index_select(dim, index) |
| |
| - name: index_copy_(Tensor self, int64_t dim, Tensor index, Tensor source) |
| self: grad.clone().index_fill_(dim, index, 0) |
| source: grad.index_select(dim, index) |
| |
| - name: index_fill_(Tensor self, int64_t dim, Tensor index, Scalar value) |
| self: grad.clone().index_fill_(dim, index, 0) |
| |
| - name: index_fill_(Tensor self, int64_t dim, Tensor index, Tensor value) |
| self: grad.clone().index_fill_(dim, index, 0) |
| value: grad.index_select(dim, index).sum() |
| |
| - name: index_select(Tensor self, int64_t dim, Tensor index) |
| self: at::zeros(grad.type(), self.sizes()).index_add_(dim, index, grad) |
| |
| - name: inverse(Tensor self) |
| self: -at::mm(output.t(), at::mm(grad, output.t())) |
| |
| - name: kthvalue(Tensor self, int64_t k, int64_t dim, bool keepdim) |
| self: select_backward(grad, dim, indices, self.sizes(), keepdim) |
| |
| - name: le_(Tensor self, Scalar other) |
| self: zeros_like(self) |
| |
| - name: le_(Tensor self, Tensor other) |
| self: zeros_like(self) |
| other: zeros_like(other) |
| |
| - name: lerp(Tensor self, Tensor end, Scalar weight) |
| self: grad * (1 - weight.toDouble()) |
| end: grad * weight |
| |
| - name: lgamma(Tensor self) |
| self: grad * digamma(self) |
| |
| - name: digamma(Tensor self) |
| self: grad * polygamma(1, self) |
| |
| - name: polygamma(int64_t n, Tensor self) |
| self: grad * polygamma(n + 1, self) |
| |
| - name: log(Tensor self) |
| self: grad.div(self) |
| |
| - name: log1p(Tensor self) |
| self: grad / (self + 1) |
| |
| - name: log_normal_(Tensor self, double mean, double std, Generator generator) |
| self: zeros_like(grad) |
| |
| - name: lt_(Tensor self, Scalar other) |
| self: zeros_like(self) |
| |
| - name: lt_(Tensor self, Tensor other) |
| self: zeros_like(self) |
| other: zeros_like(other) |
| |
| - name: masked_fill_(Tensor self, Tensor mask, Scalar value) |
| self: grad.clone().masked_fill_(mask, 0) |
| |
| - name: masked_fill_(Tensor self, Tensor mask, Tensor value) |
| self: grad.clone().masked_fill_(mask, 0) |
| value: at::where(mask, grad, zeros_like(grad)).sum() |
| |
| - name: masked_scatter_(Tensor self, Tensor mask, Tensor source) |
| self: grad.clone().masked_fill_(mask, 0) |
| source: masked_scatter_backward(grad, mask, source.sizes()) |
| |
| - name: masked_select(Tensor self, Tensor mask) |
| self: zeros_like(self).masked_scatter_(mask, grad) |
| |
| - name: max(Tensor self, int64_t dim, bool keepdim) |
| self: select_backward(grad, dim, max_indices, self.sizes(), keepdim) |
| |
| - name: max(Tensor self) |
| self: select_backward_scalar(grad, self, result) |
| |
| - name: max(Tensor self, Tensor other) |
| self: grad.clone().masked_fill_(self <= other, 0) |
| other: grad.clone().masked_fill_(self > other, 0) |
| |
| - name: mean(Tensor self, int64_t dim, bool keepdim) |
| self: sum_backward(grad, self.sizes(), dim, keepdim) / _safe_size(self.sizes(), dim) |
| |
| - name: mean(Tensor self) |
| self: grad.expand(self.sizes()) / self.numel() |
| |
| - name: median(Tensor self) |
| self: select_backward_scalar(grad, self, result) |
| |
| - name: median(Tensor self, int64_t dim, bool keepdim) |
| self: select_backward(grad, dim, indices, self.sizes(), keepdim) |
| |
| - name: min(Tensor self, int64_t dim, bool keepdim) |
| self: select_backward(grad, dim, min_indices, self.sizes(), keepdim) |
| |
| - name: min(Tensor self) |
| self: select_backward_scalar(grad, self, result) |
| |
| - name: min(Tensor self, Tensor other) |
| self: grad.clone().masked_fill_(self >= other, 0) |
| other: grad.clone().masked_fill_(self < other, 0) |
| |
| - name: _mm(Tensor self, Tensor mat2) |
| self: mm_mat1_backward(grad, mat2, self.sizes(), self.strides(), 1) |
| mat2: mm_mat2_backward(grad, self, mat2.sizes(), mat2.strides(), 1) |
| |
| - name: mode(Tensor self, int64_t dim, bool keepdim) |
| self: select_backward(grad, dim, indices, self.sizes(), keepdim) |
| |
| - name: mul(Tensor self, Scalar other) |
| self: grad * other |
| |
| - name: mul(Tensor self, Tensor other) |
| self: grad * other |
| other: grad * self |
| |
| - name: mv(Tensor self, Tensor vec) |
| self: grad.ger(vec) |
| vec: self.t().mv(grad) |
| |
| - name: ne_(Tensor self, Scalar other) |
| self: zeros_like(self) |
| |
| - name: ne_(Tensor self, Tensor other) |
| self: zeros_like(self) |
| other: zeros_like(other) |
| |
| - name: neg(Tensor self) |
| self: grad.neg() |
| |
| - name: norm(Tensor self, Scalar p) |
| self: norm_backward(grad, self, p, result) |
| |
| - name: norm(Tensor self, Scalar p, int64_t dim, bool keepdim) |
| self: norm_backward(grad, self, p, result, dim, keepdim) |
| |
| - name: normal_(Tensor self, double mean, double std, Generator generator) |
| self: zeros_like(grad) |
| |
| - name: normal(Tensor mean, double std, Generator generator) |
| mean: at::zeros(grad.type(), mean.sizes()) |
| |
| - name: normal(double mean, Tensor std, Generator generator) |
| std: at::zeros(grad.type(), std.sizes()) |
| |
| - name: normal(Tensor mean, Tensor std, Generator generator) |
| mean: at::zeros(grad.type(), mean.sizes()) |
| std: at::zeros(grad.type(), std.sizes()) |
| |
| - name: orgqr(Tensor self, Tensor input2) |
| self: not_implemented("orgqr") |
| input2: not_implemented("orgqr") |
| |
| - name: ormqr(Tensor self, Tensor input2, Tensor input3, bool left, bool transpose) |
| self: not_implemented("ormqr") |
| input2: not_implemented("ormqr") |
| input3: not_implemented("ormqr") |
| |
| - name: permute(Tensor self, IntList dims) |
| self: permute_backwards(grad, dims) |
| |
| - name: poisson(Tensor self, Generator generator) |
| self: zeros_like(self) |
| |
| - name: potrf(Tensor self, bool upper) |
| self: potrf_backward(grad, upper, output) |
| |
| - name: potri(Tensor self, bool upper) |
| self: not_implemented("potri") |
| |
| - name: potrs(Tensor self, Tensor input2, bool upper) |
|   self: not_implemented("potrs")
|   input2: not_implemented("potrs")
| |
| - name: pow(Tensor self, Scalar exponent) |
| self: grad * exponent * self.pow(exponent.toDouble() - 1) |
| |
| - name: pow(Tensor self, Tensor exponent) |
| self: grad * exponent * self.pow(exponent - 1) |
| exponent: grad * self.pow(exponent) * self.log() |
| |
| - name: prod(Tensor self, int64_t dim, bool keepdim) |
| self: prod_backward(grad, self, result, dim, keepdim) |
| |
| - name: prod(Tensor self) |
| self: prod_backward(grad, self, result) |
| |
| - name: pstrf(Tensor self, bool upper, Scalar tol) |
| self: not_implemented("pstrf") |
| |
| - name: put_(Tensor self, Tensor index, Tensor source, bool accumulate) |
| self: grad.clone().put_(index, zeros_like(source), accumulate) |
| source: grad.take(index) |
| |
| - name: qr(Tensor self) |
| self: not_implemented("qr") |
| |
| - name: random_(Tensor self, int64_t from, int64_t to, Generator generator) |
| self: zeros_like(grad) |
| |
| - name: random_(Tensor self, int64_t to, Generator generator) |
| self: zeros_like(grad) |
| |
| - name: random_(Tensor self, Generator generator) |
| self: zeros_like(grad) |
| |
| - name: reciprocal(Tensor self) |
| self: -grad * result * result |
| |
| - name: remainder(Tensor self, Scalar other) |
| self: grad |
| |
| - name: remainder(Tensor self, Tensor other) |
| self: grad |
| |
| - name: renorm(Tensor self, Scalar p, int64_t dim, Scalar maxnorm) |
| self: renorm_backward(grad, self, p, dim, maxnorm) |
| |
| - name: repeat(Tensor self, IntList repeats) |
| self: repeat_backward(grad, self.dim(), repeats) |
| |
| - name: RoiPooling2d_forward(Tensor input, Tensor rois, int64_t pooledHeight, int64_t pooledWidth, double spatialScale) |
| input: RoiPooling2d_backward(input, rois, pooledHeight, pooledWidth, spatialScale, grad, result1) |
| |
| - name: round(Tensor self) |
| self: zeros_like(grad) |
| |
| - name: rsqrt(Tensor self) |
| self: -0.5 * grad * result.pow(3) |
| |
| - name: scatter_(Tensor self, int64_t dim, Tensor index, Tensor src) |
| self: grad.clone().scatter_(dim, index, 0) |
| src: grad.gather(dim, index) |
| |
| - name: scatter_(Tensor self, int64_t dim, Tensor index, Scalar value) |
| self: grad.clone().scatter_(dim, index, 0) |
| |
| - name: scatter_add_(Tensor self, int64_t dim, Tensor index, Tensor src) |
| self: grad |
| src: grad.gather(dim, index) |
| |
| - name: sigmoid(Tensor self) |
| self: _sigmoid_backward(grad, result) |
| |
| - name: sign(Tensor self) |
| self: zeros_like(grad) |
| |
| - name: sin(Tensor self) |
| self: grad * self.cos() |
| |
| - name: sinh(Tensor self) |
| self: grad * self.cosh() |
| |
| - name: sort(Tensor self, int64_t dim, bool descending) |
| self: select_backward(grad, dim, indices, self.sizes(), true) |
| |
| - name: split(Tensor self, int64_t split_size, int64_t dim) |
| self: split_backward(grads, split_size, dim, self.sizes(), self.type()) |
| |
| - name: split_with_sizes(Tensor self, IntList split_sizes, int64_t dim) |
| self: split_with_sizes_backward(grads, split_sizes, dim, self.sizes(), self.type()) |
| |
| - name: sqrt(Tensor self) |
| self: grad / (2 * result) |
| |
| - name: squeeze(Tensor self) |
|   self: unsqueeze_to(grad, self.sizes())
| |
| - name: squeeze(Tensor self, int64_t dim) |
| self: unsqueeze_to(grad, dim, self.sizes()) |
| |
| - name: std(Tensor self, bool unbiased) |
| self: var_backward(grad / (result * 2), self, unbiased) |
| |
| - name: std(Tensor self, int64_t dim, bool unbiased, bool keepdim) |
| self: var_backward(grad / (result * 2), self, dim, unbiased, keepdim) |
| |
| - name: sub(Tensor self, Scalar other, *, Scalar alpha) |
| self: grad |
| |
| - name: sub(Tensor self, Tensor other, *, Scalar alpha) |
| self: grad |
| other: -grad * alpha |
| |
| - name: sum(Tensor self) |
| self: grad.expand(self.sizes()) |
| |
| - name: sum(Tensor self, int64_t dim, bool keepdim) |
| self: sum_backward(grad, self.sizes(), dim, keepdim) |
| |
| - name: svd(Tensor self, bool some) |
| self: svd_backward(grads, self, some, res1, res2, res3) |
| |
| - name: symeig(Tensor self, bool eigenvectors, bool upper) |
| self: not_implemented("symeig") |
| |
| - name: t(Tensor self) |
| self: grad.t() |
| |
| - name: take(Tensor self, Tensor index) |
| self: zeros_like(self).put_(index, grad, true) |
| |
| - name: tan(Tensor self) |
| self: grad * (1 + result.pow(2)) |
| |
| - name: tanh(Tensor self) |
| self: _tanh_backward(grad, result) |
| |
| - name: topk(Tensor self, int64_t k, int64_t dim, bool largest, bool sorted) |
| self: select_backward(grad, dim, indices, self.sizes(), true) |
| |
| - name: trace(Tensor self) |
| self: trace_backward(grad, self.sizes()) |
| |
| - name: transpose(Tensor self, int64_t dim0, int64_t dim1) |
| self: grad.transpose(dim0, dim1) |
| |
| - name: tril(Tensor self, int64_t diagonal) |
| self: grad.tril(diagonal) |
| |
| - name: triu(Tensor self, int64_t diagonal) |
| self: grad.triu(diagonal) |
| |
| - name: trtrs(Tensor self, Tensor A, bool upper, bool transpose, bool unitriangular) |
| self, A: trtrs_backward(grads[0], grads[1], self, A, res1, upper, transpose, unitriangular, grad_input_mask) |
| |
| - name: trunc(Tensor self) |
| self: zeros_like(grad) |
| |
| - name: unfold(Tensor self, int64_t dimension, int64_t size, int64_t step) |
| self: unfold_backward(grad, self.sizes(), dimension, size, step) |
| |
| - name: uniform_(Tensor self, double from, double to, Generator generator) |
| self: zeros_like(grad) |
| |
| - name: _unique(Tensor self, bool sorted, bool return_inverse) |
| self: not_implemented("_unique") |
| |
| - name: _unsafe_view(Tensor self, IntList size) |
| self: grad.contiguous().view(self.sizes()) |
| |
| - name: unsqueeze(Tensor self, int64_t dim) |
| self: grad.squeeze(dim) |
| |
| - name: var(Tensor self, bool unbiased) |
| self: var_backward(grad, self, unbiased) |
| |
| - name: var(Tensor self, int64_t dim, bool unbiased, bool keepdim) |
| self: var_backward(grad, self, dim, unbiased, keepdim) |
| |
| - name: view(Tensor self, IntList size) |
| self: grad.contiguous().view(self.sizes()) |
| |
| - name: _s_where(Tensor condition, Tensor self, Tensor other) |
| self: where(condition, grad, zeros_like(grad)) |
| other: where(condition, zeros_like(grad), grad) |
| |
| - name: zero_(Tensor self) |
| self: zeros_like(grad) |
| |
| - name: _sparse_mask(Tensor self, SparseTensor mask) |
| self: not_implemented("_sparse_mask") |
| mask: not_implemented("_sparse_mask") |
| |
| - name: _standard_gamma(Tensor self, Generator generator) |
| self: grad * self._standard_gamma_grad(output) |
| |
| - name: _standard_gamma_grad(Tensor self, Tensor output) |
| self: not_implemented("_standard_gamma_grad") |
| |
| # NN |
| |
| - name: binary_cross_entropy_forward(Tensor self, Tensor target, Tensor weight, bool size_average, bool reduce) |
| self: binary_cross_entropy_backward(grad, self, target, weight, size_average, reduce) |
| |
| - name: embedding(Tensor weight, Tensor indices, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) |
| weight: embedding_backward(grad, indices, weight.size(0), padding_idx, scale_grad_by_freq, sparse) |
| |
| - name: embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int64_t mode, bool sparse) |
| weight: embedding_bag_backward(grad, indices, offsets, result1, result2, weight.size(0), scale_grad_by_freq, mode, sparse) |
| |
| - name: embedding_renorm_(Tensor self, Tensor indices, double max_norm, double norm_type) |
| self: not_implemented("embedding_renorm") |
| |
| - name: kl_div_forward(Tensor self, Tensor target, bool size_average, bool reduce) |
| self: kl_div_backward(grad, self, target, size_average, reduce) |
| |
| - name: l1_loss_forward(Tensor self, Tensor target, bool size_average, bool reduce) |
| self: l1_loss_backward(grad, self, target, size_average, reduce) |
| |
| - name: mse_loss_forward(Tensor self, Tensor target, bool size_average, bool reduce) |
| self: mse_loss_backward(grad, self, target, size_average, reduce) |
| |
| - name: multi_margin_loss_forward(Tensor self, Tensor target, Scalar p, Scalar margin, Tensor weight, bool size_average, bool reduce) |
| self: multi_margin_loss_backward(grad, self, target, p, margin, weight, size_average, reduce) |
| |
| - name: multilabel_margin_loss_forward(Tensor self, Tensor target, bool size_average, bool reduce) |
| self: multilabel_margin_loss_backward(grad, self, target, size_average, reduce, is_target) |
| |
| - name: nll_loss_forward(Tensor self, Tensor target, Tensor weight, bool size_average, int64_t ignore_index, bool reduce) |
| self: nll_loss_backward(grad, self, target, weight, size_average, ignore_index, reduce, total_weight) |
| |
| - name: nll_loss2d_forward(Tensor self, Tensor target, Tensor weight, bool size_average, int64_t ignore_index, bool reduce) |
| self: nll_loss2d_backward(grad, self, target, weight, size_average, ignore_index, reduce, total_weight) |
| |
| - name: smooth_l1_loss_forward(Tensor self, Tensor target, bool size_average, bool reduce) |
| self: smooth_l1_loss_backward(grad, self, target, size_average, reduce) |
| |
| - name: soft_margin_loss_forward(Tensor self, Tensor target, bool size_average, bool reduce) |
| self: soft_margin_loss_backward(grad, self, target, size_average, reduce) |
| |
| - name: elu_forward(Tensor self, Scalar alpha, Scalar scale) |
| self: elu_backward(grad, alpha, scale, output) |
| |
| - name: glu_forward(Tensor self, int64_t dim) |
| self: glu_backward(grad, self, dim) |
| |
| - name: hardshrink_forward(Tensor self, Scalar lambd) |
| self: hardshrink_backward(grad, self, lambd) |
| |
| - name: hardtanh_forward(Tensor self, Scalar min_val, Scalar max_val) |
| self: hardtanh_backward(grad, self, min_val, max_val) |
| |
| - name: hardtanh_forward_(Tensor self, Scalar min_val, Scalar max_val) |
| self: hardtanh_backward(grad, output, min_val, max_val) |
| |
| - name: leaky_relu_forward(Tensor self, Scalar negative_slope) |
| self: leaky_relu_backward(grad, self, negative_slope) |
| |
| - name: leaky_relu_forward_(Tensor self, Scalar negative_slope) |
| self: leaky_relu_backward(grad, output, negative_slope) |
| |
| - name: log_sigmoid_forward(Tensor self) |
| self: log_sigmoid_backward(grad, self, buffer) |
| |
| - name: log_softmax_forward(Tensor self, int64_t dim) |
| self: log_softmax_backward(grad, self, dim, output) |
| |
| - name: prelu_forward(Tensor self, Tensor weight) |
| self, weight: prelu_backward(grad, self, weight, grad_input_mask) |
| |
| - name: rrelu_with_noise_forward(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator generator) |
| self: rrelu_with_noise_backward(grad, self, noise, lower, upper, training) |
| |
| - name: rrelu_with_noise_forward_(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator generator) |
| self: rrelu_with_noise_backward(grad, output, noise, lower, upper, training) |
| |
| - name: softmax_forward(Tensor self, int64_t dim) |
| self: softmax_backward(grad, self, dim, output) |
| |
| - name: softplus_forward(Tensor self, Scalar beta, Scalar threshold) |
| self: softplus_backward(grad, self, beta, threshold, output) |
| |
| - name: softshrink_forward(Tensor self, Scalar lambd) |
| self: softshrink_backward(grad, self, lambd) |
| |
| - name: threshold_forward(Tensor self, Scalar threshold, Scalar value) |
| self: threshold_backward(grad, self, threshold, value) |
| |
| - name: threshold_forward_(Tensor self, Scalar threshold, Scalar value) |
| self: threshold_backward(grad, output, threshold, value) |
| |
| - name: reflection_pad1d_forward(Tensor self, IntList padding) |
| self: reflection_pad1d_backward(grad, self, padding) |
| |
| - name: reflection_pad2d_forward(Tensor self, IntList padding) |
| self: reflection_pad2d_backward(grad, self, padding) |
| |
| - name: replication_pad1d_forward(Tensor self, IntList padding) |
| self: replication_pad1d_backward(grad, self, padding) |
| |
| - name: replication_pad2d_forward(Tensor self, IntList padding) |
| self: replication_pad2d_backward(grad, self, padding) |
| |
| - name: replication_pad3d_forward(Tensor self, IntList padding) |
| self: replication_pad3d_backward(grad, self, padding) |
| |
| - name: upsample_linear1d_forward(Tensor self, IntList output_size) |
| self: upsample_linear1d_backward(grad, output_size, self.sizes()) |
| |
| - name: upsample_bilinear2d_forward(Tensor self, IntList output_size) |
| self: upsample_bilinear2d_backward(grad, output_size, self.sizes()) |
| |
| - name: upsample_trilinear3d_forward(Tensor self, IntList output_size) |
| self: upsample_trilinear3d_backward(grad, output_size, self.sizes()) |
| |
| - name: upsample_nearest1d_forward(Tensor self, int64_t scale_factor) |
| self: upsample_nearest1d_backward(grad, self, scale_factor) |
| |
| - name: upsample_nearest2d_forward(Tensor self, int64_t scale_factor) |
| self: upsample_nearest2d_backward(grad, self, scale_factor) |
| |
| - name: upsample_nearest3d_forward(Tensor self, int64_t scale_factor) |
| self: upsample_nearest3d_backward(grad, self, scale_factor) |
| |
| - name: adaptive_avg_pool2d_forward(Tensor self, IntList output_size) |
| self: adaptive_avg_pool2d_backward(grad, self) |
| |
| - name: adaptive_avg_pool3d_forward(Tensor self, IntList output_size) |
| self: adaptive_avg_pool3d_backward(grad, self) |
| |
| - name: adaptive_max_pool2d_forward(Tensor self, IntList output_size) |
| self: adaptive_max_pool2d_backward(grad, self, indices) |
| |
| - name: adaptive_max_pool3d_forward(Tensor self, IntList output_size) |
| self: adaptive_max_pool3d_backward(grad, self, indices) |
| |
| - name: avg_pool2d_forward(Tensor self, IntList kernel_size, IntList stride, IntList padding, bool ceil_mode, bool count_include_pad) |
| self: avg_pool2d_backward(grad, self, kernel_size, stride, padding, ceil_mode, count_include_pad) |
| |
| - name: avg_pool3d_forward(Tensor self, IntList kernel_size, IntList stride, IntList padding, bool ceil_mode, bool count_include_pad) |
| self: avg_pool3d_backward(grad, self, kernel_size, stride, padding, ceil_mode, count_include_pad) |
| |
| - name: fractional_max_pool2d_forward(Tensor self, IntList kernel_size, IntList output_size, Tensor random_samples) |
| self: fractional_max_pool2d_backward(grad, self, kernel_size, output_size, indices) |
| |
| - name: max_pool2d_forward(Tensor self, IntList kernel_size, IntList stride, IntList padding, IntList dilation, bool ceil_mode) |
| self: max_pool2d_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode, indices) |
| |
| - name: max_pool3d_forward(Tensor self, IntList kernel_size, IntList stride, IntList padding, IntList dilation, bool ceil_mode) |
| self: max_pool3d_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode, indices) |
| |
| - name: max_unpool2d_forward(Tensor self, Tensor indices, IntList output_size) |
| self: max_unpool2d_backward(grad, self, indices, output_size) |
| |
| - name: max_unpool3d_forward(Tensor self, Tensor indices, IntList output_size, IntList stride, IntList padding) |
| self: max_unpool3d_backward(grad, self, indices, output_size, stride, padding) |
| |
| - name: thnn_batch_norm_forward(Tensor self, Tensor weight, Tensor bias, Tensor running_mean, Tensor running_var, bool training, double momentum, double eps) |
| self, weight, bias: thnn_batch_norm_backward(grad.contiguous(), self, weight, running_mean, running_var, training, eps, save_mean, save_std, grad_input_mask) |
| |
| - name: thnn_batch_norm_backward(Tensor grad_output, Tensor self, Tensor weight, Tensor running_mean, Tensor running_var, bool training, double eps, Tensor save_mean, Tensor save_std, std::array<bool,3> output_mask) |
| self, weight, grad_output: batchnorm_double_backward(self, weight, grads[0], grads[1], grads[2], grad_output, running_mean, running_var, training, eps, save_mean, save_std, grad_input_mask) |
| save_mean: not_implemented("thnn_batch_norm_backward save_mean") |
| save_std: not_implemented("thnn_batch_norm_backward save_std") |
| |
| - name: thnn_conv_transpose2d_forward(Tensor self, Tensor weight, IntList kernel_size, Tensor bias, IntList stride, IntList padding, IntList output_padding, IntList dilation) |
| self, weight, bias: thnn_conv_transpose2d_backward(grad, self, weight, kernel_size, stride, padding, output_padding, dilation, columns, ones, grad_input_mask) |
| |
| - name: thnn_conv_transpose2d_backward(Tensor grad_output, Tensor self, Tensor weight, IntList kernel_size, IntList stride, IntList padding, IntList output_padding, IntList dilation, Tensor columns, Tensor ones, std::array<bool,3> output_mask) |
| grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, true, output_padding, 1, false, false, false, grad_input_mask) |
| |
| - name: thnn_conv_transpose3d_forward(Tensor self, Tensor weight, IntList kernel_size, Tensor bias, IntList stride, IntList padding, IntList output_padding, IntList dilation) |
| self, weight, bias: thnn_conv_transpose3d_backward(grad, self, weight, kernel_size, stride, padding, output_padding, dilation, finput, fgrad_input, grad_input_mask) |
| |
| - name: thnn_conv_transpose3d_backward(Tensor grad_output, Tensor self, Tensor weight, IntList kernel_size, IntList stride, IntList padding, IntList output_padding, IntList dilation, Tensor finput, Tensor fgrad_input, std::array<bool,3> output_mask) |
| grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, true, output_padding, 1, false, false, false, grad_input_mask) |
| |
| - name: thnn_conv2d_forward(Tensor self, Tensor weight, IntList kernel_size, Tensor bias, IntList stride, IntList padding) |
| self, weight, bias: thnn_conv2d_backward(grad, self, weight, kernel_size, stride, padding, finput, fgrad_input, grad_input_mask) |
| |
| - name: thnn_conv2d_backward(Tensor grad_output, Tensor self, Tensor weight, IntList kernel_size, IntList stride, IntList padding, Tensor finput, Tensor fgrad_input, std::array<bool,3> output_mask) |
| grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, {{1, 1}}, false, {{0, 0}}, 1, false, false, false, grad_input_mask) |
| |
| - name: thnn_conv_depthwise2d_forward(Tensor self, Tensor weight, IntList kernel_size, Tensor bias, IntList stride, IntList padding, IntList dilation) |
| self, weight: thnn_conv_depthwise2d_backward(grad.contiguous(), self, weight, kernel_size, stride, padding, dilation, grad_input_mask) |
| bias: grad.contiguous().view({grad.size(0), grad.size(1), -1}).sum(0).sum(1) |
| |
| - name: thnn_conv_depthwise2d_backward(Tensor grad_output, Tensor self, Tensor weight, IntList kernel_size, IntList stride, IntList padding, IntList dilation, std::array<bool,2> output_mask) |
| grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], {}, grad_output, weight, self, stride, padding, dilation, false, {{0, 0}}, self.size(1), false, false, false, grad_input_mask) |
| |
| - name: thnn_conv3d_forward(Tensor self, Tensor weight, IntList kernel_size, Tensor bias, IntList stride, IntList padding) |
| self, weight, bias: thnn_conv3d_backward(grad, self, weight, kernel_size, stride, padding, finput, fgrad_input, grad_input_mask) |
| |
| - name: thnn_conv3d_backward(Tensor grad_output, Tensor self, Tensor weight, IntList kernel_size, IntList stride, IntList padding, Tensor finput, Tensor fgrad_input, std::array<bool,3> output_mask) |
| grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, {{1, 1, 1}}, false, {{0, 0, 0}}, 1, false, false, false, grad_input_mask) |
| |
| - name: thnn_conv_dilated2d_forward(Tensor self, Tensor weight, IntList kernel_size, Tensor bias, IntList stride, IntList padding, IntList dilation) |
| self, weight, bias: thnn_conv_dilated2d_backward(grad, self, weight, kernel_size, stride, padding, dilation, columns, ones, grad_input_mask) |
| |
| - name: thnn_conv_dilated2d_backward(Tensor grad_output, Tensor self, Tensor weight, IntList kernel_size, IntList stride, IntList padding, IntList dilation, Tensor columns, Tensor ones, std::array<bool,3> output_mask) |
| grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, {{0, 0}}, 1, false, false, false, grad_input_mask) |
| |
| - name: thnn_conv_dilated3d_forward(Tensor self, Tensor weight, IntList kernel_size, Tensor bias, IntList stride, IntList padding, IntList dilation) |
| self, weight, bias: thnn_conv_dilated3d_backward(grad, self, weight, kernel_size, stride, padding, dilation, columns, ones, grad_input_mask) |
| |
| - name: thnn_conv_dilated3d_backward(Tensor grad_output, Tensor self, Tensor weight, IntList kernel_size, IntList stride, IntList padding, IntList dilation, Tensor columns, Tensor ones, std::array<bool,3> output_mask) |
| grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, {{0, 0, 0}}, 1, false, false, false, grad_input_mask) |
| |
| # NN double backwards support |
| |
| - name: adaptive_avg_pool2d_backward(Tensor grad_output, Tensor self) |
| grad_output: adaptive_avg_pool2d(grad, { grad_output.size(-2), grad_output.size(-1) }) |
| self: zeros_like(self) |
| |
| - name: adaptive_avg_pool3d_backward(Tensor grad_output, Tensor self) |
| grad_output: adaptive_avg_pool3d(grad, { grad_output.size(-3), grad_output.size(-2), grad_output.size(-1) }) |
| self: zeros_like(self) |
| |
| - name: adaptive_max_pool2d_backward(Tensor grad_output, Tensor self, Tensor indices) |
| grad_output: max_pool_double_backward(grad, indices, 2) |
| self: zeros_like(self) |
| |
| - name: adaptive_max_pool3d_backward(Tensor grad_output, Tensor self, Tensor indices) |
| grad_output: max_pool_double_backward(grad, indices, 3) |
| self: zeros_like(self) |
| |
| - name: avg_pool2d_backward(Tensor grad_output, Tensor self, IntList kernel_size, IntList stride, IntList padding, bool ceil_mode, bool count_include_pad) |
| grad_output: avg_pool2d(grad, kernel_size, stride, padding, ceil_mode, count_include_pad) |
| self: zeros_like(self) |
| |
| - name: avg_pool3d_backward(Tensor grad_output, Tensor self, IntList kernel_size, IntList stride, IntList padding, bool ceil_mode, bool count_include_pad) |
| grad_output: avg_pool3d(grad, kernel_size, stride, padding, ceil_mode, count_include_pad) |
| self: zeros_like(self) |
| |
| - name: elu_backward(Tensor grad_output, Scalar alpha, Scalar scale, Tensor output) |
| grad_output: elu_backward(grad, alpha, scale, output) |
| output: grad * grad_output * (output < 0).toType(grad.type()) |
| |
| - name: fractional_max_pool2d_backward(Tensor grad_output, Tensor self, IntList kernel_size, IntList output_size, Tensor indices) |
| grad_output: max_pool_double_backward(grad, indices, 2) |
| self: zeros_like(self) |
| |
| - name: glu_backward(Tensor grad_output, Tensor self, int64_t dim) |
| grad_output: glu_double_backward_grad_output(grad, self, dim) |
| self: glu_double_backward(grad, grad_output, self, dim) |
| |
| - name: hardshrink_backward(Tensor grad_output, Tensor self, Scalar lambd) |
| grad_output: hardshrink_backward(grad, self, lambd) |
| self: zeros_like(grad) |
| |
| - name: hardtanh_backward(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val) |
| grad_output: hardtanh_backward(grad, self, min_val, max_val) |
| self: zeros_like(grad) |
| |
| - name: kl_div_backward(Tensor grad_output, Tensor self, Tensor target, bool size_average, bool reduce) |
| grad_output: kl_div_double_backward_grad_output(grad, self, target, size_average, reduce) |
| self: zeros_like(grad) |
| |
| - name: l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, bool size_average, bool reduce) |
| grad_output: l1_loss_double_backward_grad_output(grad, self, target, size_average, reduce) |
| self: zeros_like(grad) |
| |
| - name: log_sigmoid_backward(Tensor grad_output, Tensor self, Tensor buffer) |
| grad_output: log_sigmoid_backward(grad, self, buffer) |
| self: log_sigmoid_double_backward(grad * grad_output, self) |
| |
| - name: log_softmax_backward(Tensor grad_output, Tensor self, int64_t dim, Tensor output) |
| grad_output: grad - (grad * output.exp()).sum(dim, true) |
| self: log_softmax_double_backward(grad, grad_output, dim, output) |
| |
| - name: leaky_relu_backward(Tensor grad_output, Tensor self, Scalar negative_slope) |
| grad_output: leaky_relu_backward(grad, self, negative_slope) |
| self: zeros_like(grad) |
| |
| - name: max_pool2d_backward(Tensor grad_output, Tensor self, IntList kernel_size, IntList stride, IntList padding, IntList dilation, bool ceil_mode, Tensor indices) |
|   grad_output: max_pool_double_backward(grad, indices, 2)
| self: zeros_like(self) |
| |
| - name: max_pool3d_backward(Tensor grad_output, Tensor self, IntList kernel_size, IntList stride, IntList padding, IntList dilation, bool ceil_mode, Tensor indices) |
|   grad_output: max_pool_double_backward(grad, indices, 3)
| self: zeros_like(self) |
| |
| - name: max_unpool2d_backward(Tensor grad_output, Tensor self, Tensor indices, IntList output_size) |
| grad_output: max_unpool2d(grad, indices, output_size) |
| self: zeros_like(self) |
| |
| - name: mse_loss_backward(Tensor grad_output, Tensor self, Tensor target, bool size_average, bool reduce) |
| grad_output: mse_loss_double_backward_grad_output(grad, grad_output, self, target, size_average, reduce) |
| self: mse_loss_double_backward(grad * grad_output, self, size_average, reduce) |
| |
| - name: nll_loss_backward(Tensor grad_output, Tensor self, Tensor target, Tensor weight, bool size_average, int64_t ignore_index, bool reduce, Tensor total_weight) |
| grad_output: nll_loss(grad, target, weight, size_average, ignore_index, reduce) |
| self: zeros_like(grad) |
| |
| - name: nll_loss2d_backward(Tensor grad_output, Tensor self, Tensor target, Tensor weight, bool size_average, int64_t ignore_index, bool reduce, Tensor total_weight) |
| grad_output: nll_loss2d(grad, target, weight, size_average, ignore_index, reduce) |
| self: zeros_like(grad) |
| |
| - name: prelu_backward(Tensor grad_output, Tensor self, Tensor weight, std::array<bool,2> output_mask) |
| grad_output, self, weight: prelu_double_backward(grads[0], grads[1], grad_output, self, weight, grad_input_mask) |
| |
| - name: rrelu_with_noise_backward(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training) |
| grad_output: rrelu_with_noise_backward(grad, self, noise, lower, upper, training) |
| self: zeros_like(grad) |
| |
| - name: reflection_pad1d_backward(Tensor grad_output, Tensor self, IntList padding) |
| grad_output: reflection_pad1d(grad, padding) |
| self: zeros_like(self) |
| |
| - name: reflection_pad2d_backward(Tensor grad_output, Tensor self, IntList padding) |
| grad_output: reflection_pad2d(grad, padding) |
| self: zeros_like(self) |
| |
| - name: replication_pad1d_backward(Tensor grad_output, Tensor self, IntList padding) |
| grad_output: replication_pad1d(grad, padding) |
| self: zeros_like(self) |
| |
| - name: replication_pad2d_backward(Tensor grad_output, Tensor self, IntList padding) |
| grad_output: replication_pad2d(grad, padding) |
| self: zeros_like(self) |
| |
| - name: replication_pad3d_backward(Tensor grad_output, Tensor self, IntList padding) |
| grad_output: replication_pad3d(grad, padding) |
| self: zeros_like(self) |
| |
| - name: smooth_l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, bool size_average, bool reduce) |
| grad_output: smooth_l1_loss_double_backward_grad_output(grad, grad_output, self, target, size_average, reduce) |
| self: smooth_l1_loss_double_backward(grad * grad_output, self, target, size_average, reduce) |
| |
| - name: softplus_backward(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold, Tensor output) |
| grad_output: softplus_backward(grad, self, beta, threshold, output) |
| self: softplus_double_backward(grad * grad_output, self, beta, threshold) |
| |
| - name: softmax_backward(Tensor grad_output, Tensor self, int64_t dim, Tensor output) |
| grad_output: softmax_backward(grad, self, dim, output) |
| self: softmax_double_backward(grad, grad_output, dim, output) |
| |
| - name: soft_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, bool size_average, bool reduce) |
| grad_output: soft_margin_loss_double_backward_grad_output(grad, grad_output, self, target, size_average, reduce) |
| self: soft_margin_loss_double_backward(grad * grad_output, self, target, size_average, reduce) |
| |
| - name: softshrink_backward(Tensor grad_output, Tensor self, Scalar lambd) |
| grad_output: softshrink_backward(grad, self, lambd) |
| self: zeros_like(grad) |
| |
| - name: threshold_backward(Tensor grad_output, Tensor self, Scalar threshold, Scalar value) |
| grad_output: threshold_backward(grad, self, threshold, value) |
| self: zeros_like(grad) |
| |
| - name: upsample_linear1d_backward(Tensor grad_output, IntList output_size, IntList input_size) |
| grad_output: upsample_linear1d(grad, output_size) |
| |
| - name: upsample_bilinear2d_backward(Tensor grad_output, IntList output_size, IntList input_size) |
| grad_output: upsample_bilinear2d(grad, output_size) |
| |
| - name: upsample_trilinear3d_backward(Tensor grad_output, IntList output_size, IntList input_size) |
| grad_output: upsample_trilinear3d(grad, output_size) |
| |
| - name: upsample_nearest1d_backward(Tensor grad_output, Tensor self, int64_t scale_factor) |
| grad_output: upsample_nearest1d(grad, scale_factor) |
| self: zeros_like(grad) |
| |
| - name: upsample_nearest2d_backward(Tensor grad_output, Tensor self, int64_t scale_factor) |
| grad_output: upsample_nearest2d(grad, scale_factor) |
| self: zeros_like(grad) |
| |
| - name: upsample_nearest3d_backward(Tensor grad_output, Tensor self, int64_t scale_factor) |
| grad_output: upsample_nearest3d(grad, scale_factor) |
| self: zeros_like(grad) |
| |
| - name: _sigmoid_backward(Tensor grad_output, Tensor output) |
| grad_output: _sigmoid_backward(grad, output) |
| output: grad * grad_output * (-2 * output + 1) |
| |
| - name: _tanh_backward(Tensor grad_output, Tensor output) |
| grad_output: _tanh_backward(grad, output) |
| output: -2 * output * grad * grad_output |
| |
| # cudnn |
| |
| - name: cudnn_convolution_transpose(Tensor self, Tensor weight, Tensor bias, IntList padding, IntList output_padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic) |
| self, weight, bias: cudnn_convolution_transpose_backward(self, grad, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic, grad_input_mask) |
| |
| - name: cudnn_convolution_transpose_backward(Tensor self, Tensor grad_output, Tensor weight, IntList padding, IntList output_padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic, std::array<bool,3> output_mask) |
| grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, true, output_padding, groups, benchmark, deterministic, true, grad_input_mask) |
| |
| - name: cudnn_convolution(Tensor self, Tensor weight, Tensor bias, IntList padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic) |
| self, weight, bias: cudnn_convolution_backward(self, grad, weight, padding, stride, dilation, groups, benchmark, deterministic, grad_input_mask) |
| |
| - name: cudnn_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, IntList padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic, std::array<bool,3> output_mask) |
| grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, std::vector<int64_t>(padding.size(), 0), groups, benchmark, deterministic, true, grad_input_mask) |
| |
| # The above backward definitions are equivalent to the definitions below. Why do
| # we bundle everything up? Because it's more convenient to define double
| # backwards when there is a single function that manages everything.
| #
| # Unfortunately, there's one downside to bundling everything up: we
| # unconditionally save input and weight, even if the weight/input gradients are
| # not being computed. That's too bad.
| # |
| # input: cudnn_convolution_backward_input(input.sizes(), grad.contiguous(), weight, padding, stride, dilation, groups, benchmark, deterministic) |
| # weight: cudnn_convolution_backward_weight(weight.sizes(), grad.contiguous(), input, padding, stride, dilation, groups, benchmark, deterministic) |
| # bias: cudnn_convolution_backward_bias(grad.contiguous()) |
| # |
| # input: cudnn_convolution_transpose_backward_input(grad.contiguous(), weight, padding, stride, dilation, groups, benchmark, deterministic) |
| # weight: cudnn_convolution_transpose_backward_weight(weight.sizes(), grad.contiguous(), input, padding, stride, dilation, groups, benchmark, deterministic) |
| # bias: cudnn_convolution_backward_bias(grad.contiguous()) |
| |
| - name: cudnn_grid_sampler(Tensor self, Tensor grid) |
| self, grid: cudnn_grid_sampler_backward(self, grid, grad) |
| |
| - name: cudnn_affine_grid_generator(Tensor theta, int64_t N, int64_t C, int64_t H, int64_t W) |
| theta: cudnn_affine_grid_generator_backward(grad, N, C, H, W) |
| |
| # NB: Why is the backwards here so complicated? CuDNN cannot be used to compute |
| # backward in evaluation mode, because the math for backward in evaluation mode |
| # is different (since the forward math is different), and CuDNN does not support |
| # it. And in any case, you shouldn't be using this batch norm in evaluation
| # mode, because it should be merged into the preceding convolution (left for
| # future work).
| # NB2: The quotes around the gradient formula are needed because it contains a
| # ':', which YAML would otherwise parse as a key-value separator.
| - name: cudnn_batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor running_mean, Tensor running_var, bool training, double exponential_average_factor, double epsilon) |
| input, weight, bias: "training ? cudnn_batch_norm_backward(input, grad.contiguous(), weight, running_mean, running_var, result1, result2, epsilon) : thnn_batch_norm_backward(grad.contiguous(), input, weight, running_mean, running_var, training, epsilon, result1, result2, grad_input_mask)" |
| |
| # HACK: save_mean and save_var are going to be passed in as |
| # requires_grad variables (even though we'll never backprop through |
| # them) so we need to prevent the unpacking from triggering an error. |
| - name: cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor running_mean, Tensor running_var, Tensor save_mean, Tensor save_var, double epsilon) |
| save_mean: not_implemented("cudnn_batch_norm_backward save_mean") |
| save_var: not_implemented("cudnn_batch_norm_backward save_var") |
| input, weight, grad_output: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_output, running_mean, running_var, true, epsilon, save_mean, save_var, grad_input_mask) |
| |
| # nnpack |
| |
| - name: nnpack_spatial_convolution(Tensor input, Tensor weight, Tensor bias, int64_t kW, int64_t kH, int64_t padW, int64_t padH) |
| input: nnpack_spatial_convolution_backward_input(input, grad, weight, kW, kH, padW, padH) |
| weight: nnpack_spatial_convolution_backward_weight(input, weight.sizes(), grad, kW, kH, padW, padH) |
| bias: grad.contiguous().view({grad.size(0), grad.size(1), -1}).sum(0).sum(1) |
| |
| - name: _cudnn_rnn(Tensor input, TensorList weight, int64_t weight_stride0, Tensor weight_buf, Tensor hx, Tensor cx, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, IntList batch_sizes, Tensor dropout_state) |
| input, hx, cx, weight: "_cudnn_rnn_backward(input, weight, weight_stride0, result4, hx, cx, result0, grads[0], grads[1], grads[2], mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, retain_variables ? result3.clone() : result3, grad_input_mask)" |