| # Defines derivative formulas and Python signatures of methods on Variable |
| # |
| # Note about possibly confusing nomenclature: An 'output gradient' is the |
| # gradient of an output of a forward function. Output gradients are used as |
| # the inputs to backward functions. `grads` is a vector of output gradients, |
| # and `grad == grads[0]`, in all the derivative formulas in this file. |
| # An 'input gradient' is the gradient of an input to a forward function. |
| # Input gradients are the outputs of backward functions, corresponding to the |
| # input names included in the derivative formulas defined in this file. |
| # Also, whenever we talk about computing a "gradient", we actually mean computing |
| # the vector-Jacobian product using the given 'output gradient' as the vector. |
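| # |
| # As a purely illustrative example: for an element-wise product out = a * b, |
| # the Jacobian of out with respect to a is diag(b), so the input gradient for |
| # a is the vector-Jacobian product grad * b -- which is what the mul.Tensor |
| # entry below computes via mul_tensor_backward. |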
| # |
| # Each entry consists of: |
| # - A 'name', which specifies the ATen name of the function you |
| # are defining derivatives for, and an argument specification. |
| # - An optional 'dispatch' entry which can be used to specify |
| # per-autograd dispatch key derivatives. If this entry is not |
| # specified, then the gradient entries will be taken as the |
| # default gradients (i.e. registered for every backward dispatch |
| # key). (see _test_autograd_multiple_dispatch for an example |
| # of how to register separate derivatives for different dispatch keys). |
| # The list of allowed dispatch keys (in addition to 'Default' which |
| # represents the Autograd alias key) is torchgen/model.py:AUTOGRAD_KEYS. |
| # - One or more gradient entries, mapping each differentiable input |
| # name to a formula specifying how to compute its gradient. |
| # Note that a single gradient entry can specify the gradient |
| # formula for multiple input names, by specifying a key |
| # "input1, input2" (see atan2 for an example). |
| # - An argument can be flagged as 'non_differentiable'. |
| # - Optional entry with key 'output_differentiability' and value a list of the |
| # same length as the number of outputs from the forward function. The list |
| # should contain only booleans, specifying whether each output Tensor |
| # is differentiable. |
| # If it is not specified for a function that returns multiple elements but |
| # uses `grad` instead of `grads[idx]`, then all but the first output will |
| # be marked as non-differentiable. |
| # If none of the outputs are differentiable, you can also add the function |
| # name to `gen_variable_type.py`'s `DONT_REQUIRE_DERIVATIVE` list. |
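| # |
| # Putting these pieces together, a sketch of an entry (the op and the backward |
| # helpers here are hypothetical, shown for illustration only) could look like: |
| # |
| # - name: my_op(Tensor self, Tensor other) -> (Tensor, Tensor) |
| #   self: my_op_backward_self(grad, self, other) |
| #   other: my_op_backward_other(grad, self, other) |
| #   output_differentiability: [True, False] |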
| # |
| # There are two cases for Tensor and TensorList arguments here: |
| # - If that argument is differentiable, in the sense that a gradient with respect |
| # to that argument could exist, you should either: |
| # - Specify the formula for that gradient |
| # - Specify not_implemented("function_name") as a formula to say that this is not |
| # implemented yet (but might be in the future, and the user can request it in an issue) |
| # - If that argument is not differentiable (for example, because it is not a floating |
| # point dtype, or because the function is not differentiable with respect to that |
| # argument), you should either: |
| # - Not specify any formula for this argument |
| # - Specify explicitly that this argument is "non_differentiable". Note that in this case, |
| # we trust you that this argument will never have requires_grad=True, and it will be silently |
| # ignored if it does. |
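| # |
| # For instance (hypothetical op, shown for illustration only), a differentiable |
| # input whose formula is not implemented yet and a non-differentiable index |
| # input could be declared as: |
| # |
| # - name: my_select_like(Tensor self, Tensor index) -> Tensor |
| #   self: not_implemented("my_select_like") |
| #   index: non_differentiable |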
| # |
| # If a function has out-of-place and in-place variants, then the derivative |
| # definition for the in-place variant is optional. It will default to the |
| # definition for the out-of-place variant. Note that _out variants are never |
| # differentiable. |
| # |
| # Gradient expressions are standard C++ expressions operating on ATen |
| # variables. In a gradient expression, the following variables are in |
| # scope: |
| # |
| # - 'grad', the gradient of the output (often spelled grad_output |
| # in Python) which we are going to left-multiply. |
| # |
| # When a function returns multiple *differentiable* outputs, |
| # you can refer to the gradients of each output using 'grads', |
| # e.g., 'grads[0]', 'grads[1]'. |
| # |
| # When a function returns multiple *differentiable* outputs that |
| # are named, you can refer to the gradients of each output using |
| # 'grad_{name}', e.g., 'grad_x', 'grad_y'. |
| # |
| # When a function returns *one* differentiable output (the |
| # first output) and some more nondifferentiable outputs, |
| # you MUST refer to the gradient of the differentiable output with |
| # 'grad' (this case is special-cased in our code generation). |
| # |
| # Note that the number of differentiable outputs can be modified by the |
| # 'output_differentiability' entry (see above). |
| # |
| # Across a differentiable function's derivatives set, it is not |
| # permitted to mix the use of "grad", "grads", and |
| # "grad_{name}". You must be consistent for that differentiable |
| # function. |
| # |
| # - Any of the input arguments, tensor or non-tensor, including |
| # argument names that only appear in Declarations.yaml, e.g. 'output'. |
| # |
| # - 'result', representing the result of evaluating the forward |
| # expression for ATen native function declarations. If the forward |
| # expression outputs a tuple, use 'resultX' instead to access the |
| # X-th entry |
| # |
| # - 'grad_input_mask', a std::array<bool, n>, specifies which input |
| # gradients are actually needed. For example, in the entry |
| # `input0, input1: foo(grad_input_mask)`, `grad_input_mask` is a size |
| # two array, where `grad_input_mask[0]` is true if `input0` requires |
| # grad, and `grad_input_mask[1]` is true if `input1` requires grad. |
| # |
| # (NB: if your function computes the gradient for a list of tensors, |
| # `grad_input_mask` will only have a single entry for the whole list, |
| # specifying whether at least one tensor from the list requires |
| # grad. If we want to support more fine-grained signalling, |
| # we'll need some alternate variable which is not a std::array.) |
| # |
| # - 'retain_variables', a bool which is true if a user has specified |
| # that saved variables should be retained in case the backward pass is |
| # run again later. This allows an optimization where we can |
| # destroy saved buffers if we know variables are not going to be retained, |
| # e.g., it is used by _cudnn_rnn. |
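| # |
| # For example (hypothetical entry, shown for illustration only), a function with |
| # two named differentiable outputs can refer to their gradients as 'grad_L' and |
| # 'grad_U' (or, equivalently, as 'grads[0]' and 'grads[1]', but never a mix of |
| # the two styles): |
| # |
| # - name: my_decomp(Tensor self) -> (Tensor L, Tensor U) |
| #   self: my_decomp_backward(grad_L, grad_U, self) |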
| # |
| # If you need a complex expression, e.g., with local variables, |
| # write a _backward function in torch/csrc/autograd/FunctionsManual.cpp |
| # and invoke it from here. By the way, go read |
| # https://github.com/zdevito/ATen/issues/163; this describes an |
| # important hazard that occurs when porting backwards from Python to C++. |
| # |
| # Double backwards gradient expressions can be somewhat confusing; |
| # the most important thing to remember is: (1) you need to define a |
| # derivative formula for every input, including inputs named things |
| # like 'grad_output', and (2) the gradient to multiply with is always |
| # called 'grad' (even though it really is a grad-grad). |
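| # |
| # A sketch of what this looks like (hypothetical op, mirroring real entries such |
| # as hardswish_backward later in this file): the backward op itself gets formulas |
| # for 'grad_output' and 'self', and the incoming grad-grad is still spelled 'grad': |
| # |
| # - name: my_relu_like_backward(Tensor grad_output, Tensor self) -> Tensor |
| #   grad_output: my_relu_like_backward(grad, self) |
| #   self: zeros_like(grad) |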
| # |
| # You can also add a forward derivative definition by defining a formula for |
| # a returned value (in general "result" if the name is not specified). This |
| # formula works the same way as the backward one, and advanced implementations |
| # should also be placed in the FunctionsManual file. |
| # This formula should compute a single Jacobian-vector product using the (primal) |
| # value of the argument "foo_p", its forward grad "foo_t", and the result of the |
| # function as "result". |
| # Note that the forward derivative can be automatically generated in two cases: |
| # - if your function is linear (NOT affine or multi-linear), then you can |
| # specify so by just using the string "auto_linear" for the formula. |
| # - if your function is applied element-wise (and has a single input), you |
| # can specify so by just using the string "auto_element_wise" for the formula. |
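| # |
| # As an illustration (hypothetical op; real dtypes assumed for simplicity), an |
| # explicit forward derivative is written in terms of the primal "self_p" and the |
| # tangent "self_t"; an element-wise op like this one could equally just use |
| # "auto_element_wise": |
| # |
| # - name: my_square(Tensor self) -> Tensor |
| #   self: grad * 2 * self |
| #   result: self_t * 2 * self_p |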
| # |
| # Note that to avoid unpacking overhead, functions taking TensorList as inputs |
| # will always have their forward grad formula called. That formula is responsible |
| # for checking whether any computation is needed, and should return an undefined |
| # Tensor when there is nothing to do. You can check "cat_forward" for a full example. |
| # |
| # NB: There are a number of gradient definitions in here which are bogus |
| # (implemented using zeros_like). These gradients are (hopefully) not |
| # used by our frontend. You MUST check the frontend code; search for |
| # OpName.apply to see if it's still using a legacy Python style API. |
| # |
| # Note: Returning views. |
| # The following cases exist: |
| # - If a function returns no view, it can have arbitrary outputs. |
| # - If a function returns at least one Tensor that is a differentiable view |
| # of one of its inputs: |
| # - If there is only one differentiable output, this Tensor is marked as a |
| # differentiable view (alias or transpose, for example). |
| # - If there is more than one differentiable output, by default all the views are |
| # marked as differentiable views and created with allow_rebase_history=false, |
| # meaning that any in-place operation on them will raise an error (unbind, for example). |
| # |
| # Notes about undefined output gradients: |
| # All backward functions must support all combinations of undefined output |
| # gradient Tensors, where `grad[i].defined() == false`. Depending on the |
| # number of input and output grads your derivative formula uses, code |
| # generation may automatically add some level of undefined grad support, |
| # according to these three cases: |
| # |
| # * 1 input grad and 1 output grad: |
| # Complete undefined grad support is automatically added, so you |
| # shouldn't have to think about it, unless there is a bug in the code |
| # generation. |
| # |
| # * 1 input grad and multiple output grads: |
| # Undefined grad support is automatically added ONLY in the case where |
| # all output grads are undefined. You will have to add explicit support |
| # for cases where a subset of output grads is undefined. |
| # |
| # * multiple input grads: |
| # No automatic support, so you will need to add it. |
| # |
| # If your derivative formula uses more than one output grad, it is usually |
| # preferable to add undefined grad support in the backward function itself |
| # (if you're using one), rather than in the derivative formula in this file. |
| # |
| # Undefined Tensors are created with the default constructor `at::Tensor()`. |
| # It is an efficient way to represent a Tensor filled with zeros because |
| # the Tensor holds no sizing information and no Storage data is allocated. |
| # But consequently, Tensor operations cannot be performed on them. |
| # Therefore, your backward function should treat an undefined output grad as |
| # a zero, and it needs to handle this as a special case. |
| # |
| # If all output grads are undefined, then it should be correct for the |
| # backward function to return undefined input grads. Since we use the chain |
| # rule, output grads equal to zero should result in input grads equal to zero, |
| # unless there is some rare special case. |
| # |
| # If a subset of output grads is undefined, then it may be acceptable for |
| # the backward function to return undefined input grads--it depends on the |
| # specific function, so you'll have to determine that yourself. If returning |
| # an undefined Tensor is correct for a given input grad, it is also logically |
| # correct to return a defined grad full of zeros, but that would not be |
| # preferable since it would be less efficient. |
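| # |
| # A common way to provide such support directly in a formula (hypothetical op |
| # shown; entries like conv_tbc in this file use the same pattern) is an explicit |
| # guard on grad.defined(): |
| # |
| # - name: my_binary_op(Tensor self, Tensor other) -> Tensor |
| #   self, other: "grad.defined() ? my_binary_op_backward(grad, self, other, grad_input_mask) : std::tuple<Tensor, Tensor>()" |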
| # |
| # NB: The parameter names here MUST be consistent with the parameter names |
| # in native_functions.yaml |
| - name: abs(Tensor self) -> Tensor |
| self: grad * self.sgn() |
| result: handle_r_to_c(result.scalar_type(), self_t.conj() * self_p.sgn()) |
| |
| - name: acos(Tensor self) -> Tensor |
| self: grad * -((-self * self + 1).rsqrt()).conj() |
| result: auto_element_wise |
| |
| - name: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor |
| self: handle_r_to_c(self.scalar_type(), grad) |
| other: handle_r_to_c(other.scalar_type(), maybe_multiply(grad, alpha.conj())) |
| result: self_t + maybe_multiply(other_t, alpha) |
| |
| - name: add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor |
| self: handle_r_to_c(self.scalar_type(), grad) |
| result: self_t.clone() |
| |
| - name: addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor |
| self: maybe_multiply(grad, beta.conj()) |
| batch1: maybe_multiply(grad.unsqueeze(0).expand_symint({ batch1.sym_size(0), batch1.sym_size(1), batch2.sym_size(2) }).bmm(batch2.transpose(1, 2).conj()), alpha.conj()) |
| batch2: maybe_multiply(batch1.transpose(1, 2).conj().bmm(grad.unsqueeze(0).expand_symint({ batch1.sym_size(0), batch1.sym_size(1), batch2.sym_size(2) })), alpha.conj()) |
| result: maybe_multiply(self_t, beta) + maybe_multiply(batch1_t.bmm(batch2_p).sum(0), alpha) + maybe_multiply(batch1_p.bmm(batch2_t).sum(0), alpha) |
| |
| - name: addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor |
| self: handle_r_to_c(self.scalar_type(), grad) |
| tensor1: handle_r_to_c(tensor1.scalar_type(), grad * (value / tensor2).conj()) |
| tensor2: handle_r_to_c(tensor2.scalar_type(), -grad * (value * tensor1 / (tensor2 * tensor2)).conj()) |
| result: self_t + maybe_multiply(tensor1_t / tensor2_p, value) - maybe_multiply(tensor2_t * (tensor1_p / tensor2_p) / tensor2_p, value) |
| |
| - name: addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor |
| self: handle_r_to_c(self.scalar_type(), grad) |
| tensor1: handle_r_to_c(tensor1.scalar_type(), grad * (tensor2 * value).conj()) |
| tensor2: handle_r_to_c(tensor2.scalar_type(), grad * (tensor1 * value).conj()) |
| result: self_t + maybe_multiply(tensor1_t * tensor2_p, value) + maybe_multiply(tensor2_t * tensor1_p, value) |
| |
| - name: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor |
| self: maybe_multiply(grad, beta.conj()) |
| mat1: mm_mat1_backward(grad, mat2, mat1.sym_sizes(), mat1.sym_strides(), mat1.layout(), alpha) |
| mat2: mm_mat2_backward(grad, mat1, mat2.sym_sizes(), mat2.sym_strides(), mat2.layout(), alpha) |
| result: maybe_multiply(self_t, beta) + maybe_multiply(mat1_t.mm(mat2_p), alpha) + maybe_multiply(mat1_p.mm(mat2_t), alpha) |
| |
| - name: _sparse_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor |
| self: maybe_multiply(grad, beta) |
| mat1: mm_mat1_sparse_backward(grad, mat1, mat2, alpha) |
| mat2: mm_mat2_backward(grad, mat1, mat2.sym_sizes(), mat2.sym_strides(), mat2.layout(), alpha) |
| |
| - name: addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor |
| self: maybe_multiply(grad, beta.conj()) |
| mat: maybe_multiply(grad.ger(vec.conj()), alpha.conj()) |
| vec: maybe_multiply(mat.t().conj().mv(grad), alpha.conj()) |
| result: maybe_multiply(self_t, beta) + maybe_multiply(mat_t.mv(vec_p), alpha) + maybe_multiply(mat_p.mv(vec_t), alpha) |
| |
| - name: addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor |
| self: maybe_multiply(grad, beta.conj()) |
| vec1: maybe_multiply(grad.mv(vec2.conj()), alpha.conj()) |
| vec2: maybe_multiply(grad.t().mv(vec1.conj()), alpha.conj()) |
| result: maybe_multiply(self_t, beta) + maybe_multiply(vec1_t.outer(vec2_p), alpha) + maybe_multiply(vec1_p.outer(vec2_t), alpha) |
| |
| - name: affine_grid_generator(Tensor theta, int[] size, bool align_corners) -> Tensor |
| theta: affine_grid_generator_backward(grad, size, align_corners) |
| |
| - name: alias(Tensor(a) self) -> Tensor(a) |
| self: grad |
| result: self_t |
| |
| - name: angle(Tensor self) -> Tensor |
| self: angle_backward(grad, self) |
| result: handle_r_to_c(result.scalar_type(), angle_backward(self_t.conj(), self_p).conj()) |
| |
| # The four items below are necessary because TensorIterator doesn't work on |
| # Variables (codegen does not unwrap the input Tensor for all() and any() ). |
| - name: any(Tensor self) -> Tensor |
| output_differentiability: [False] |
| |
| - name: any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor |
| output_differentiability: [False] |
| |
| - name: _is_all_true(Tensor self) -> Tensor |
| self: non_differentiable |
| |
| - name: _is_any_true(Tensor self) -> Tensor |
| self: non_differentiable |
| |
| - name: all(Tensor self) -> Tensor |
| output_differentiability: [False] |
| |
| - name: all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor |
| output_differentiability: [False] |
| |
| - name: acosh(Tensor self) -> Tensor |
| # Save one rsqrt in the real case by using the fact that, for real positive x and y, sqrt(x*y) = sqrt(x)*sqrt(y) (not true in the complex case) |
| self: "self.is_complex() ? grad * ((self + 1).rsqrt() * (self - 1).rsqrt()).conj() : grad * (self * self - 1).rsqrt()" |
| result: auto_element_wise |
| |
| - name: acosh_(Tensor(a!) self) -> Tensor(a!) |
| self: not_implemented("inplace version of acosh") |
| |
| - name: asinh(Tensor self) -> Tensor |
| self: grad * (self.pow(2) + 1).rsqrt().conj() |
| result: auto_element_wise |
| |
| - name: asinh_(Tensor(a!) self) -> Tensor(a!) |
| self: not_implemented("inplace version of asinh") |
| |
| - name: atanh(Tensor self) -> Tensor |
| self: grad * 1 / (1 - self.pow(2)).conj() |
| result: auto_element_wise |
| |
| - name: atanh_(Tensor(a!) self) -> Tensor(a!) |
| self: not_implemented("inplace version of atanh") |
| |
| - name: as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a) |
| self: as_strided_backward(grad, TensorGeometry(self), size, stride, storage_offset) |
| result: auto_linear |
| |
| - name: as_strided_(Tensor(a!) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a!) |
| self: as_strided_backward(grad, TensorGeometry(self), size, stride, storage_offset) |
| result: auto_linear |
| |
| - name: asin(Tensor self) -> Tensor |
| self: grad * (-self * self + 1).rsqrt().conj() |
| result: auto_element_wise |
| |
| - name: atan(Tensor self) -> Tensor |
| self: grad / (self * self + 1).conj() |
| result: auto_element_wise |
| |
| - name: atan2(Tensor self, Tensor other) -> Tensor |
| self, other: atan2_backward(grad, self, other, grad_input_mask) |
| result: (-self_p * other_t + other_p * self_t) / (self_p.pow(2) + other_p.pow(2)) |
| |
| - name: baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor |
| self: maybe_multiply(grad, beta.conj()) |
| batch1: maybe_multiply(grad.bmm(batch2.transpose(1, 2).conj()), alpha.conj()) |
| batch2: maybe_multiply(batch1.transpose(1, 2).conj().bmm(grad), alpha.conj()) |
| result: maybe_multiply(self_t, beta) + maybe_multiply(batch1_t.bmm(batch2_p), alpha) + maybe_multiply(batch1_p.bmm(batch2_t), alpha) |
| |
| - name: bernoulli(Tensor self, *, Generator? generator=None) -> Tensor |
| self: zeros_like(grad) |
| result: auto_element_wise |
| |
| - name: bernoulli_.Tensor(Tensor(a!) self, Tensor p, *, Generator? generator=None) -> Tensor(a!) |
| self: zeros_like(grad) |
| p: zeros_like(p) |
| result: self_t.zero_() |
| |
| - name: bernoulli_.float(Tensor(a!) self, float p=0.5, *, Generator? generator=None) -> Tensor(a!) |
| self: zeros_like(grad) |
| result: self_t.zero_() |
| |
| - name: bmm(Tensor self, Tensor mat2) -> Tensor |
| self: grad.bmm(mat2.transpose(1, 2).conj()) |
| mat2: self.transpose(1, 2).conj().bmm(grad) |
| result: self_t.bmm(mat2_p) + self_p.bmm(mat2_t) |
| |
| - name: matmul(Tensor self, Tensor other) -> Tensor |
| self, other: matmul_backward(grad, self, other, grad_input_mask) |
| |
| - name: cat(Tensor[] tensors, int dim=0) -> Tensor |
| tensors: cat_tensors_backward(grad, to_args_sizes_symint(tensors), to_args_scalartypes(tensors), dim) |
| result: cat_jvp(tensors, dim) |
| |
| - name: cauchy_(Tensor(a!) self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor(a!) |
| self: zeros_like(grad) |
| result: self_t.zero_() |
| |
| - name: ceil(Tensor self) -> Tensor |
| self: zeros_like(grad) |
| result: auto_element_wise |
| |
| - name: cholesky(Tensor self, bool upper=False) -> Tensor |
| self: cholesky_backward(grad, upper, result) |
| |
| - name: linalg_cholesky_ex(Tensor self, *, bool upper=False, bool check_errors=False) -> (Tensor L, Tensor info) |
| self: cholesky_backward(grad, upper, L) |
| L: cholesky_jvp(self_t, L, upper) |
| |
| - name: cholesky_solve(Tensor self, Tensor input2, bool upper=False) -> Tensor |
| self, input2: cholesky_solve_backward(grad, self, input2, result, upper) |
| result: cholesky_solve_jvp(result, input2_p, input2_t, self_t, upper) |
| |
| - name: cholesky_inverse(Tensor self, bool upper=False) -> Tensor |
| self: cholesky_inverse_backward(grad, self, upper, result) |
| result: cholesky_inverse_jvp(self_p, self_t, result, upper) |
| |
| # For clamp, the gradient is not defined at the boundaries. But empirically it's helpful |
| # to be able to get a gradient for min and max, so we return the subgradient 1 for these cases. |
| - name: clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor |
| self: clamp_backward(grad, self, min, max) |
| min, max: clamp_backward_min_max(grad, self, min, max, grad_input_mask) |
| result: clamp_jvp(self_p, self_t, min_p, min_t, max_p, max_t) |
| |
| - name: clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor |
| self: clamp_backward(grad, self, min, max) |
| result: auto_element_wise |
| |
| - name: clamp_min(Tensor self, Scalar min) -> Tensor |
| self: where(self >= min, grad, at::scalar_tensor(0., grad.options())) |
| result: auto_element_wise |
| |
| - name: clamp_min.Tensor(Tensor self, Tensor min) -> Tensor |
| self: where(self >= min, grad, at::scalar_tensor(0., grad.options())) |
| min: where(self < min, grad, at::scalar_tensor(0., grad.options())) |
| result: where(self_p >= min_p, self_t, min_t) |
| |
| - name: clamp_max(Tensor self, Scalar max) -> Tensor |
| self: where(self <= max, grad, at::scalar_tensor(0., grad.options())) |
| result: auto_element_wise |
| |
| - name: clamp_max.Tensor(Tensor self, Tensor max) -> Tensor |
| self: where(self <= max, grad, at::scalar_tensor(0., grad.options())) |
| max: where(self > max, grad, at::scalar_tensor(0., grad.options())) |
| result: where(self_p <= max_p, self_t, max_t) |
| |
| - name: clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor |
| self: grad |
| result: auto_linear |
| |
| - name: _to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor |
| self: _to_copy_backward(grad, self.options()) |
| result: _to_copy(self_t, dtype, layout, device, pin_memory, non_blocking, memory_format) |
| # The condition is: if dtype is not nullopt, then isDifferentiableType(*dtype) |
| # (If dtype IS nullopt, we rely on the regular check that any input requires grad). |
| output_differentiability: ["!dtype || isDifferentiableType(*dtype)"] |
| |
| - name: _coalesce(Tensor self) -> Tensor |
| self: grad |
| |
| - name: complex(Tensor real, Tensor imag) -> Tensor |
| real: at::real(grad) |
| imag: at::imag(grad) |
| result: at::complex(real_t, imag_t) |
| |
| - name: polar(Tensor abs, Tensor angle) -> Tensor |
| abs, angle: polar_backward(grad, result) |
| result: at::complex(abs_t*angle_p.cos() - angle_t*abs_p*angle_p.sin(), abs_t*angle_p.sin() + angle_t*abs_p*angle_p.cos()) |
| |
| - name: _conj(Tensor(a) self) -> Tensor(a) |
| self: grad.conj() |
| result: self_t.conj() |
| |
| - name: _neg_view(Tensor(a) self) -> Tensor(a) |
| self: grad.neg() |
| result: self_t._neg_view() |
| |
| - name: _conj_physical(Tensor self) -> Tensor |
| self: grad.conj_physical() |
| result: self_t.conj_physical() |
| |
| - name: conj_physical_(Tensor(a!) self) -> Tensor(a!) |
| self: grad.conj_physical() |
| result: self_t.conj_physical_() |
| |
| - name: copysign.Tensor(Tensor self, Tensor other) -> Tensor |
| self: copysign_tensor_self_backward(grad, self, result) |
| other: zeros_like(other) |
| result: copysign_tensor_self_backward(self_t, self_p, result) |
| |
| - name: copysign.Scalar(Tensor self, Scalar other) -> Tensor |
| self: copysign_tensor_self_backward(grad, self, result) |
| result: auto_element_wise |
| |
| - name: cos(Tensor self) -> Tensor |
| self: grad * -self.sin().conj() |
| result: auto_element_wise |
| |
| - name: cosh(Tensor self) -> Tensor |
| self: grad * self.sinh().conj() |
| result: auto_element_wise |
| |
| - name: count_nonzero.dim_IntList(Tensor self, int[] dim) -> Tensor |
| output_differentiability: [False] |
| |
| - name: count_nonzero(Tensor self, int? dim=None) -> Tensor |
| output_differentiability: [False] |
| |
| - name: linalg_cross(Tensor self, Tensor other, *, int dim=-1) -> Tensor |
| self: at::linalg_cross(other.conj(), grad, dim) |
| other: at::linalg_cross(grad, self.conj(), dim) |
| result: "at::linalg_cross(self_t, other_p, dim) + at::linalg_cross(self_p, other_t, dim)" |
| |
| - name: logcumsumexp(Tensor self, int dim) -> Tensor |
| self: logcumsumexp_backward(grad, self, result, dim) |
| |
| - name: cumprod(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor |
| self: cumprod_backward(grad.to(self.scalar_type()), self, dim, result) |
| result: "cumprod_jvp(self_t, self_p, result, dim).to(dtype.has_value() ? *dtype : self_p.scalar_type())" |
| |
| - name: cumsum(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor |
| self: cumsum_backward(grad.to(self.scalar_type()), dim) |
| result: auto_linear |
| |
| - name: cummax(Tensor self, int dim) -> (Tensor values, Tensor indices) |
| self: cummaxmin_backward(grad, self, indices, dim) |
| values: self_t.gather(dim, indices) |
| |
| - name: cummin(Tensor self, int dim) -> (Tensor values, Tensor indices) |
| self: cummaxmin_backward(grad, self, indices, dim) |
| values: self_t.gather(dim, indices) |
| |
| - name: conv_tbc(Tensor self, Tensor weight, Tensor bias, int pad=0) -> Tensor |
| self, weight, bias: "grad.defined() ? conv_tbc_backward(grad, self, weight, bias, pad) : std::tuple<Tensor, Tensor, Tensor>()" |
| |
| - name: _ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, bool zero_infinity=False) -> (Tensor, Tensor) |
| log_probs: _ctc_loss_backward(grad, log_probs, targets, input_lengths, target_lengths, result0, result1, blank, zero_infinity) |
| |
| - name: _ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank=0, bool zero_infinity=False) -> (Tensor, Tensor) |
| log_probs: _ctc_loss_backward(grad, log_probs, targets, input_lengths, target_lengths, result0, result1, blank, zero_infinity) |
| |
| - name: deg2rad(Tensor self) -> Tensor |
| self: deg2rad_backward(grad) |
| result: auto_element_wise |
| |
| - name: _linalg_det(Tensor A) -> (Tensor result, Tensor LU, Tensor pivots) |
| A: linalg_det_backward(grad, result, A, LU, pivots) |
| result: linalg_det_jvp(A_t, result, LU, pivots, A_p.is_contiguous() && !A_p.is_complex()) |
| output_differentiability: [True, False, False] |
| |
| - name: _linalg_slogdet(Tensor A) -> (Tensor sign, Tensor logabsdet, Tensor LU, Tensor pivots) |
| A: slogdet_backward(grad_sign, grad_logabsdet, A, sign, LU, pivots) |
| sign, logabsdet: slogdet_jvp(LU, pivots, A_t, sign, A_p.is_contiguous() && !A_p.is_complex()) |
| output_differentiability: [True, True, False, False] |
| |
| - name: block_diag(Tensor[] tensors) -> Tensor |
| tensors: block_diag_backward(grad, to_args_sizes(tensors), to_args_scalartypes(tensors)) |
| result: block_diag_jvp(tensors) |
| |
| - name: diag_embed(Tensor self, int offset=0, int dim1=-2, int dim2=-1) -> Tensor |
| self: grad.diagonal(offset, dim1, dim2) |
| result: auto_linear |
| |
| - name: diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a) |
| self: diagonal_backward_symint(grad, self.sym_sizes(), offset, dim1, dim2) |
| result: auto_linear |
| |
| - name: diagonal_backward(Tensor grad_output, SymInt[] input_sizes, int offset, int dim1, int dim2) -> Tensor |
| grad_output: grad.diagonal(offset, dim1, dim2) |
| result: auto_linear |
| |
| - name: dist(Tensor self, Tensor other, Scalar p=2) -> Tensor |
| self: norm_backward(grad, self - other, p, result) |
| other: -norm_backward(grad, self - other, p, result) |
| result: norm_jvp(self_p - other_p, self_t - other_t, p, result, {}, false) |
| |
| # The backward formula is done in this order to improve numerical stability |
| # of the higher order derivatives, see https://github.com/pytorch/pytorch/issues/43414 |
| # Note that we don't use "result" because saving it would be BC-breaking when it is used in an inplace operation later |
| - name: div.Tensor(Tensor self, Tensor other) -> Tensor |
| self: div_tensor_self_backward(grad, other, self.scalar_type()) |
| other: div_tensor_other_backward(grad, self, other) |
| result: (self_t - other_t * result) / other_p |
| |
| - name: div.Scalar(Tensor self, Scalar other) -> Tensor |
| self: div_tensor_self_backward(grad, at::lift_fresh(at::scalar_to_tensor(other)), self.scalar_type()) |
| result: self_t / other |
| |
| - name: div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor |
| self: div_tensor_self_backward(grad, other, self.scalar_type(), rounding_mode) |
| other: div_tensor_other_backward(grad, self, other, rounding_mode) |
| result: "rounding_mode.has_value() ? result.new_zeros_symint(result.sym_sizes()) : self_t / other_p - other_t * (self_p / other_p) / other_p" |
| |
| - name: div.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor |
| self: div_tensor_self_backward(grad, at::lift_fresh(at::scalar_to_tensor(other)), self.scalar_type(), rounding_mode) |
| result: "rounding_mode.has_value() ? result.new_zeros_symint(result.sym_sizes()) : self_t / other" |
| |
| - name: dot(Tensor self, Tensor tensor) -> Tensor |
| self: grad * tensor.conj() |
| tensor: grad * self.conj() |
| result: at::dot(self_t, tensor_p) + at::dot(self_p, tensor_t) |
| |
| - name: vdot(Tensor self, Tensor other) -> Tensor |
| self: grad.conj() * other |
| other: grad * self |
| result: at::vdot(self_t, other_p) + at::vdot(self_p, other_t) |
| |
| - name: _fused_dropout(Tensor self, float p, Generator? generator=None) -> (Tensor, Tensor) |
| self: _fused_dropout_backward(grad, result1, p) |
| |
| - name: native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor) |
| input: "GradMode::is_enabled() ? infinitely_differentiable_native_dropout_backward(grad, result1, (!train.has_value() || !train.value() ? 1 : (p == 1 ? 0.0 : 1.0 / (1.0 - p)))) : native_dropout_backward(grad, result1, (!train.has_value() || !train.value() ? 1 : (p == 1 ? 0.0 : 1.0 / (1.0 - p))))" |
| result0: "(!train.has_value() || train.value()) ? (p == 1 ? 0.0 : 1.0 / (1.0 - p)) * input_t * result1 : input_t" |
| |
| - name: native_dropout_backward(Tensor grad_output, Tensor mask, float scale) -> Tensor |
| grad_output: "native_dropout_double_backward(grad, grad_output, mask, scale)" |
| mask: 'not_implemented("native_dropout_backward: mask")' |
| |
| - name: eq_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) |
| self: zeros_like(self) |
| result: self_t.zero_() |
| |
| - name: eq_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) |
| self: zeros_like(self) |
| other: zeros_like(other) |
| result: self_t.zero_() |
| |
| - name: erf(Tensor self) -> Tensor |
| self: 2.0 / sqrt(M_PI) * exp(-(self.pow(2))) * grad |
| result: auto_element_wise |
| |
| - name: erfc(Tensor self) -> Tensor |
| self: -2.0 / sqrt(M_PI) * exp(-(self.pow(2))) * grad |
| result: auto_element_wise |
| |
| - name: special_erfcx(Tensor self) -> Tensor |
| self: (2.0 * self * result - 2.0 / sqrt(M_PI)) * grad |
| result: auto_element_wise |
| |
| - name: erfinv(Tensor self) -> Tensor |
| self: 0.5 * sqrt(M_PI) * exp(self.erfinv().pow(2)) * grad |
| result: auto_element_wise |
| |
| - name: exp(Tensor self) -> Tensor |
| self: grad * result.conj() |
| result: auto_element_wise |
| |
| - name: exp2(Tensor self) -> Tensor |
| self: grad * result.conj() * M_LN2 |
| result: auto_element_wise |
| |
| - name: expm1(Tensor self) -> Tensor |
| self: grad * (result + 1) |
| result: auto_element_wise |
| |
| # TODO: this derivative is not SymInt safe, need sum_to support |
| - name: expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a) |
| self: at::sum_to(grad, self.sym_sizes()) |
| result: auto_linear |
| |
| - name: exponential_(Tensor(a!) self, float lambd=1, *, Generator? generator=None) -> Tensor(a!) |
| self: zeros_like(grad) |
| result: self_t.zero_() |
| |
| - name: fake_quantize_per_tensor_affine_cachemask(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> (Tensor output, Tensor mask) |
| self: fake_quantize_per_tensor_affine_cachemask_backward(grad, mask) |
| |
| - name: _fake_quantize_per_tensor_affine_cachemask_tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, Tensor fake_quant_enabled, int quant_min, int quant_max) -> (Tensor output, Tensor mask) |
| self: fake_quantize_per_tensor_affine_cachemask_backward(grad, mask) |
| |
| - name: _fake_quantize_learnable_per_tensor_affine(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.0) -> Tensor |
| self, scale, zero_point: "grad.defined() ? _fake_quantize_learnable_per_tensor_affine_backward(grad, self, scale, zero_point, quant_min, quant_max, grad_factor) : std::tuple<Tensor, Tensor, Tensor>()" |
| |
| - name: fake_quantize_per_channel_affine_cachemask(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> (Tensor output, Tensor mask) |
| self: fake_quantize_per_channel_affine_cachemask_backward(grad, mask) |
| |
| - name: _fake_quantize_learnable_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, float grad_factor=1.0) -> Tensor |
| self, scale, zero_point: "grad.defined() ? _fake_quantize_learnable_per_channel_affine_backward(grad, self, scale, zero_point, axis, quant_min, quant_max, grad_factor) : std::tuple<Tensor, Tensor, Tensor>()" |
| |
| - name: _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask) |
| self: fake_quantize_per_tensor_affine_cachemask_backward(grad, mask) |
| |
| - name: fill.Scalar(Tensor self, Scalar value) -> Tensor |
| self: zeros_like(grad) |
| result: at::fill(self_t, 0) |
| |
| - name: fill.Tensor(Tensor self, Tensor value) -> Tensor |
| self: zeros_like(grad) |
| value: grad.sum() |
| result: at::fill(self_t, value_t) |
| |
| - name: fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!) |
| self: zeros_like(grad) |
| result: self_t.fill_(0) |
| |
| - name: fill_.Tensor(Tensor(a!) self, Tensor value) -> Tensor(a!) |
| self: zeros_like(grad) |
| value: grad.sum() |
| result: self_t.fill_(value_t) |
| |
| - name: floor(Tensor self) -> Tensor |
| self: zeros_like(grad) |
| result: auto_element_wise |
| |
| - name: fmod.Scalar(Tensor self, Scalar other) -> Tensor |
| self: grad |
| result: auto_element_wise |
| |
| - name: fmod.Tensor(Tensor self, Tensor other) -> Tensor |
| self: grad |
| other: -grad * self.div(other, /*rounding_mode=*/"trunc") |
| result: self_t - other_t * self_p.div(other_p, /*rounding_mode=*/"trunc") |
| |
| - name: frac(Tensor self) -> Tensor |
| self: grad |
| result: self_t |
| |
| - name: frexp.Tensor(Tensor self) -> (Tensor mantissa, Tensor exponent) |
| self: grad / exponent.exp2() |
| mantissa: self_t / exponent.exp2() |
| |
| - name: gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor |
| self: gather_backward(grad, self, dim, index, sparse_grad) |
| index: non_differentiable |
| result: auto_linear |
| |
| - name: ge_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) |
| self: zeros_like(self) |
| result: self_t.zero_() |
| |
| - name: ge_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) |
| self: zeros_like(self) |
| other: zeros_like(other) |
| result: self_t.zero_() |
| |
| - name: geometric_(Tensor(a!) self, float p, *, Generator? generator=None) -> Tensor(a!) |
| self: zeros_like(grad) |
| result: self_t.zero_() |
| |
| - name: geqrf(Tensor self) -> (Tensor a, Tensor tau) |
| self: not_implemented("geqrf") |
| |
| - name: indices(Tensor(a) self) -> Tensor(a) |
| output_differentiability: [False] |
| |
| - name: _indices(Tensor(a) self) -> Tensor(a) |
| output_differentiability: [False] |
| |
| - name: grid_sampler_2d(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor |
| input, grid: "grad.defined() ? grid_sampler_2d_backward(grad, input, grid, interpolation_mode, padding_mode, align_corners, grad_input_mask) : std::tuple<Tensor, Tensor>()" |
| |
| - name: grid_sampler_3d(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor |
| input, grid: "grad.defined() ? grid_sampler_3d_backward(grad, input, grid, interpolation_mode, padding_mode, align_corners, grad_input_mask) : std::tuple<Tensor, Tensor>()" |
| |
| # See NOTE [ grid_sample CPU fallback ] |
| - name: _grid_sampler_2d_cpu_fallback(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor |
| input, grid: "grad.defined() ? _grid_sampler_2d_cpu_fallback_backward(grad, input, grid, interpolation_mode, padding_mode, align_corners) : std::tuple<Tensor, Tensor>()" |
| |
| - name: gt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) |
| self: zeros_like(self) |
| result: self_t.zero_() |
| |
| - name: gt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) |
| self: zeros_like(self) |
| other: zeros_like(other) |
| result: self_t.zero_() |
| |
| - name: hardsigmoid(Tensor self) -> Tensor |
| self: hardsigmoid_backward(grad, self) |
| result: auto_element_wise |
| |
| - name: histc(Tensor self, int bins=100, Scalar min=0, Scalar max=0) -> Tensor |
| output_differentiability: [False] |
| |
| - name: hardswish(Tensor self) -> Tensor |
| self: hardswish_backward(grad, self) |
| result: auto_element_wise |
| |
| - name: hardswish_backward(Tensor grad_output, Tensor self) -> Tensor |
| grad_output: hardswish_backward(grad, self) |
| self: at::where(at::logical_and(-3.0 < self, self < 3.0), grad * grad_output / 3.0, at::zeros({}, self.options())) |
| result: "hardswish_backward(grad_output_t, self_p) |
| + at::where(at::logical_and(-3.0 < self_p, self_p < 3.0), self_t * grad_output_p / 3.0, at::zeros({}, self_p.options()))" |
| |
| - name: hypot(Tensor self, Tensor other) -> Tensor |
| self: grad * self / result |
| other: grad * other / result |
| result: self_t * self_p / result + other_t * other_p / result |
| |
| - name: i0(Tensor self) -> Tensor |
| self: grad * at::special_i1(self) |
| result: auto_element_wise |
| |
| - name: special_i0e(Tensor self) -> Tensor |
| self: grad * (at::special_i1e(self) - self.sgn() * result) |
| result: auto_element_wise |
| |
| - name: special_i1(Tensor self) -> Tensor |
| self: i1_backward(grad, self, result) |
| result: auto_element_wise |
| |
| - name: special_i1e(Tensor self) -> Tensor |
| self: i1e_backward(grad, self, result) |
| result: auto_element_wise |
| |
| - name: igamma(Tensor self, Tensor other) -> Tensor |
| self: 'not_implemented("igamma: input")' |
| other: grad * exp((self - 1) * log(other) - other - lgamma(self)) |
| |
| - name: igammac(Tensor self, Tensor other) -> Tensor |
| self: 'not_implemented("igammac: input")' |
| other: -grad * exp((self - 1) * log(other) - other - lgamma(self)) |
| |
| - name: index.Tensor(Tensor self, Tensor?[] indices) -> Tensor |
| self: index_backward(grad.new_zeros_symint(self.sym_sizes(), self.options()), indices, grad) |
| result: auto_linear |
| |
| - name: index_add(Tensor self, int dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor |
| self: grad |
| # The case source.dim() == 0 is necessary to support scalar tensors of the form |
| # source.dim() == 0 and index.dim() == 1 and index.size() == (1,). |
| # This is because source is not broadcastable to index, as source.dim() < index.dim(). |
| source: "maybe_multiply(source.dim() > 0 ? grad.index_select(dim, index).expand_as(source) : grad.index_select(dim, index.squeeze(0)), alpha)" |
| index: non_differentiable |
| result: at::index_add(self_t, dim, index, maybe_multiply(source_t, alpha)) |
| |
| - name: index_reduce(Tensor self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True) -> Tensor |
| self, source: index_reduce_backward(grad, self, dim, index, source, reduce, include_self, result) |
| index: non_differentiable |
| |
| - name: index_copy(Tensor self, int dim, Tensor index, Tensor source) -> Tensor |
| self: grad.index_fill(dim, index, 0) |
| # The case source.dim() == 0 is necessary to support scalar tensors of the form |
| # source.dim() == 0 and index.dim() == 1 and index.size() == (1,). |
| # This is because source is not broadcastable to index, as source.dim() < index.dim(). |
| source: "source.dim() > 0 ? grad.index_select(dim, index).expand_as(source) : grad.index_select(dim, index.squeeze(0))" |
| index: non_differentiable |
| result: self_t.index_copy(dim, index, source_t) |
| |
| - name: index_fill.int_Scalar(Tensor self, int dim, Tensor index, Scalar value) -> Tensor |
| self: grad.index_fill(dim, index, 0) |
| index: non_differentiable |
| result: self_t.index_fill(dim, index, 0) |
| |
| - name: index_fill.int_Tensor(Tensor self, int dim, Tensor index, Tensor value) -> Tensor |
| self: grad.index_fill(dim, index, 0) |
| value: grad.index_select(dim, std::get<0>(at::_unique(index, /*sorted=*/false))).sum() |
| index: non_differentiable |
| result: self_t.index_fill(dim, index, value_t) |
| |
| - name: index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor |
| self: "accumulate ? grad : grad.index_put(indices, zeros_like(values), false)" |
| values: grad.index(indices) |
| result: self_t.index_put(indices, values_t, accumulate) |
| |
| - name: _index_put_impl_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False) -> Tensor(a!) |
| self: "accumulate ? grad : grad.index_put(indices, zeros_like(values), false)" |
| values: grad.index(indices) |
| result: at::_index_put_impl_(self_t, indices, values_t, accumulate, unsafe) |
| |
| - name: index_select(Tensor self, int dim, Tensor index) -> Tensor |
| self: index_select_backward_symint(grad, self.sym_sizes(), dim, index) |
| index: non_differentiable |
| result: auto_linear |
| |
| - name: linalg_inv_ex(Tensor A, *, bool check_errors=False) -> (Tensor inverse, Tensor info) |
| A: -at::matmul(inverse.mH(), at::matmul(grad, inverse.mH())) |
| inverse: -at::matmul(at::matmul(inverse, A_t), inverse) |
| output_differentiability: [True, False] |
| |
| - name: linalg_pinv.atol_rtol_tensor(Tensor self, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False) -> Tensor |
| self: pinv_backward(grad, result, self) |
| result: pinv_jvp(self_p, result, self_t) |
| |
| - name: isnan(Tensor self) -> Tensor |
| self: non_differentiable |
| |
| - name: kthvalue(Tensor self, int k, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) |
| self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim) |
| values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim) |
| |
| - name: le_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) |
| self: zeros_like(self) |
| result: self_t.zero_() |
| |
| - name: le_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) |
| self: zeros_like(self) |
| other: zeros_like(other) |
| result: self_t.zero_() |
| |
| - name: lerp.Scalar(Tensor self, Tensor end, Scalar weight) -> Tensor |
| self: "weight.isComplex() ? grad * (1 - weight.conj().toComplexDouble()) : grad * (1 - weight.toDouble())" |
| end: grad * weight.conj() |
| result: at::lerp(self_t, end_t, weight) |
| |
| - name: lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor |
| self: grad * (1 - weight).conj() |
| end: grad * weight.conj() |
| weight: grad * (end - self).conj() |
| result: at::lerp(self_t, end_t, weight_p) + weight_t * (end_p - self_p) |
| |
| - name: lgamma(Tensor self) -> Tensor |
| self: grad * digamma(self) |
| result: auto_element_wise |
| |
| - name: digamma(Tensor self) -> Tensor |
| self: grad * polygamma(1, self) |
| result: auto_element_wise |
| |
| - name: polygamma(int n, Tensor self) -> Tensor |
| self: grad * polygamma(n + 1, self) |
| result: auto_element_wise |
| |
| - name: polygamma_(Tensor(a!) self, int n) -> Tensor(a!) |
| self: grad * polygamma(n + 1, self) |
| result: self_t.mul_(polygamma(n + 1, original_self_p)) |
| |
| - name: log(Tensor self) -> Tensor |
| self: grad.div(self.conj()) |
| result: auto_element_wise |
| |
| - name: log10(Tensor self) -> Tensor |
| self: grad / (self.conj() * 2.3025850929940456) |
| result: auto_element_wise |
| |
| - name: log1p(Tensor self) -> Tensor |
| self: log1p_backward(grad, self) |
| result: auto_element_wise |
| |
| - name: log2(Tensor self) -> Tensor |
| self: grad / (self.conj() * 0.6931471805599453) |
| result: auto_element_wise |
| |
| - name: logaddexp(Tensor self, Tensor other) -> Tensor |
| self: grad / (1 + exp(other - self)) |
| other: grad / (1 + exp(self - other)) |
| result: self_t / (1 + exp(other_p - self_p)) + other_t / (1 + exp(self_p - other_p)) |
| |
| - name: logaddexp2(Tensor self, Tensor other) -> Tensor |
| self: grad / (1 + pow(2, other - self)) |
| other: grad / (1 + pow(2, self - other)) |
| result: self_t / (1 + pow(2, other_p - self_p)) + other_t / (1 + pow(2, self_p - other_p)) |
| |
| # Note [Gradient formula for xlogy at x = 0, y <= 0] |
| # x * log(y) is not defined at y <= 0, so we cannot even talk about differentiability |
| # Now, xlogy(0, y) = 0 by definition. |
| # This does not make it differentiable as it's not defined in a neighbourhood of a point |
| # (0, y) when y <= 0. |
| # Now, when a function is non-differentiable, sometimes we return "a relatively sensible value" |
| # In this case, as per the discussion in https://github.com/pytorch/pytorch/issues/80770, we choose |
| # this value to be zero, which is the directional derivative along the line {x = 0}. |
| - name: xlogy.Tensor(Tensor self, Tensor other) -> Tensor |
| self: at::xlogy(grad, other).masked_fill((self == 0.) & (other <= 0.), 0.) |
| other: grad * self / other |
| result: at::xlogy(self_t, other_p).masked_fill((self_p == 0.) & (other_p <= 0.), 0.) + other_t * self_p / other_p |
| |
| - name: xlogy.Scalar_Self(Scalar self, Tensor other) -> Tensor |
| other: grad * self / other |
| result: auto_element_wise |
| |
| - name: xlogy.Scalar_Other(Tensor self, Scalar other) -> Tensor |
| self: "other.toDouble() > 0. |
| ? at::xlogy(grad, other) |
| : at::xlogy(grad, other).masked_fill(self == 0., 0.)" |
| result: auto_element_wise |
| |
| # See Note [Gradient formula for xlogy at x = 0, y <= 0] |
| # Same here but with y <= -1 |
| - name: special_xlog1py(Tensor self, Tensor other) -> Tensor |
| self: at::special_xlog1py(grad, other).masked_fill((self == 0.) & (other <= -1.), 0.) |
| other: grad * self / (other + 1) |
| result: at::special_xlog1py(self_t, other_p).masked_fill((self_p == 0.) & (other_p <= -1.), 0.) + other_t * self_p / (other_p + 1) |
| |
| - name: special_xlog1py.self_scalar(Scalar self, Tensor other) -> Tensor |
| other: grad * self / (other + 1) |
| result: auto_element_wise |
| |
| - name: special_xlog1py.other_scalar(Tensor self, Scalar other) -> Tensor |
| self: "other.toDouble() > -1. |
| ? at::special_xlog1py(grad, other) |
| : at::special_xlog1py(grad, other).masked_fill(self == 0., 0.)" |
| result: auto_element_wise |
| |
| - name: special_zeta(Tensor self, Tensor other) -> Tensor |
| self: not_implemented("zeta") |
| other: grad * -self * special_zeta(self + 1., other) |
| |
| - name: special_zeta.self_scalar(Scalar self, Tensor other) -> Tensor |
| other: grad * -self * special_zeta(self.toDouble() + 1., other) |
| |
| - name: special_zeta.other_scalar(Tensor self, Scalar other) -> Tensor |
| self: not_implemented("zeta") |
| |
| - name: log_normal_(Tensor(a!) self, float mean=1, float std=2, *, Generator? generator=None) -> Tensor(a!) |
| self: zeros_like(grad) |
| result: self_t.zero_() |
| |
| - name: logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor |
| self: logsumexp_backward(grad, self, result, dim, keepdim) |
| result: logsumexp_jvp(self_p, self_t, dim, keepdim) |
| |
| - name: linalg_lstsq(Tensor self, Tensor b, float? rcond=None, *, str? driver=None) -> (Tensor solution, Tensor residuals, Tensor rank, Tensor singular_values) |
| self, b: linalg_lstsq_backward(grad, self, b, grad_input_mask) |
| solution: linalg_lstsq_jvp(self_p, b_p, self_t, b_t) |
| output_differentiability: [True, False, False, False] |
| |
| - name: lt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) |
| self: zeros_like(self) |
| result: self_t.zero_() |
| |
| - name: lt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) |
| self: zeros_like(self) |
| other: zeros_like(other) |
| result: self_t.zero_() |
| |
| - name: linalg_lu_factor_ex(Tensor A, *, bool pivot=True, bool check_errors=False) -> (Tensor LU, Tensor pivots, Tensor info) |
| A: lu_factor_ex_backward(grad, LU, pivots, pivot) |
| LU: lu_factor_ex_jvp(A_t, LU, pivots, pivot) |
| output_differentiability: [True, False, False] |
| |
| - name: linalg_lu(Tensor A, *, bool pivot=True) -> (Tensor P, Tensor L, Tensor U) |
| A: linalg_lu_backward(grad_L, grad_U, P, L, U, pivot) |
| L: std::get<0>(linalg_lu_jvp(A_t, P, L, U, pivot)) |
| U: std::get<1>(linalg_lu_jvp(A_t, P, L, U, pivot)) |
| output_differentiability: [False, True, True] |
| |
| - name: linalg_lu_solve(Tensor LU, Tensor pivots, Tensor B, *, bool left=True, bool adjoint=False) -> Tensor |
| LU: linalg_lu_solve_LU(grad, LU, pivots, result, left, adjoint) |
| B: "at::linalg_lu_solve(LU, pivots, grad, left, !adjoint)" |
| result: linalg_lu_solve_jvp(result, LU_p, pivots, LU_t, B_t, left, adjoint) |
| |
| - name: lu_unpack(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True) -> (Tensor P, Tensor L, Tensor U) |
| LU_data: lu_unpack_backward(grad_L, grad_U, LU_data.sym_size(-2), LU_data.sym_size(-1)) |
| LU_pivots: non_differentiable |
| L: "LU_data_t.sym_size(-2) >= LU_data_t.sym_size(-1) ? LU_data_t.tril(-1) : LU_data_t.narrow_symint(-1, 0, LU_data_t.sym_size(-2)).tril(-1)" |
| U: "LU_data_t.sym_size(-1) >= LU_data_t.sym_size(-2) ? LU_data_t.triu() : LU_data_t.narrow_symint(-2, 0, LU_data_t.sym_size(-1)).triu()" |
| output_differentiability: [False, True, True] |
| |
| - name: masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> Tensor |
| self: grad.masked_fill(mask, 0) |
| mask: non_differentiable |
| result: self_t.masked_fill(mask, 0) |
| |
| - name: masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor |
| self: grad.masked_fill(mask, 0) |
| value: masked_fill_backward(grad, mask) |
| mask: non_differentiable |
| result: self_t.masked_fill(mask, value_t) |
| |
| - name: masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor |
| self: grad.masked_fill(mask, 0) |
| source: masked_scatter_backward(grad, mask, source.sym_sizes()) |
| mask: non_differentiable |
| result: self_t.masked_scatter(mask, source_t) |
| |
| - name: masked_select(Tensor self, Tensor mask) -> Tensor |
| self: masked_select_backward(grad, self, mask) |
| mask: non_differentiable |
| result: auto_linear |
| |
| - name: linalg_matrix_exp(Tensor self) -> Tensor |
| self: linalg_matrix_exp_differential(self, grad, /*adjoint*/ true) |
| result: linalg_matrix_exp_differential(self_p, self_t, /*adjoint*/ false) |
| |
| - name: max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) |
| self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim) |
| values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim) |
| |
| - name: max(Tensor self) -> Tensor |
| self: evenly_distribute_backward(grad, self, result) |
| result: evenly_read_jvp(self_t, self_p, result) |
| |
| - name: maximum(Tensor self, Tensor other) -> Tensor |
| self: at::where(self == other, grad / 2, grad).masked_fill_(self < other, 0) |
| other: at::where(self == other, grad / 2, grad).masked_fill_(self > other, 0) |
| result: other_t + at::where(self_p == other_p, at::scalar_tensor(0.5, result.options()), (self_p > other_p).to(result.scalar_type())) * (self_t - other_t) |
| |
| - name: fmax(Tensor self, Tensor other) -> Tensor |
| self: grad.masked_fill((self >= other).logical_or_(other.isnan()).logical_not_(), 0) |
| other: grad.masked_fill((self >= other).logical_or_(other.isnan()), 0) |
| result: other_t + (self_p > other_p).logical_or_(other_p.isnan()) * (self_t - other_t) |
| |
| - name: mean(Tensor self, *, ScalarType? dtype=None) -> Tensor |
| self: grad.expand_symint(self.sym_sizes()) / self.sym_numel() |
| result: auto_linear |
| |
| - name: mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor |
| self: mean_backward(grad, self.sym_sizes(), dim, self.sym_numel(), keepdim) |
| result: auto_linear |
| |
| - name: median(Tensor self) -> Tensor |
| self: evenly_distribute_backward(grad, self, result) |
| result: evenly_read_jvp(self_t, self_p, result) |
| |
| - name: nanmedian(Tensor self) -> Tensor |
| self: evenly_distribute_backward(grad, self, result) |
| result: evenly_read_jvp(self_t, self_p, result) |
| |
| # This is in theory incorrect in the following case: |
| # sorted list: [..., a, b, b, ..., b, b, c, ...] with median = b, where the |
| # middle position of the list falls strictly between two `b`s. |
| # The gradient exists and is essentially 0 in this case. |
| # |
| # In the case where the middle position sits at the boundary of the `b` range |
| # (i.e., on the last `b` before `c` in the sorted list above), the backward |
| # implementation is correct in the sense that it returns the |
| # subgradient on one side. |
| - name: median.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) |
| self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim) |
| values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim) |
| |
| - name: nanmedian.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) |
| self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim) |
| values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim) |
| |
| - name: min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) |
| self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim) |
| values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim) |
| |
| - name: min(Tensor self) -> Tensor |
| self: evenly_distribute_backward(grad, self, result) |
| result: evenly_read_jvp(self_t, self_p, result) |
| |
| - name: minimum(Tensor self, Tensor other) -> Tensor |
| self: at::where(self == other, grad / 2, grad).masked_fill_(self > other, 0) |
| other: at::where(self == other, grad / 2, grad).masked_fill_(self < other, 0) |
| result: other_t + at::where(self_p == other_p, at::scalar_tensor(0.5, result.options()), (self_p < other_p).to(result.scalar_type())) * (self_t - other_t) |
| |
| - name: fmin(Tensor self, Tensor other) -> Tensor |
| self: grad.masked_fill((self <= other).logical_or_(other.isnan()).logical_not_(), 0) |
| other: grad.masked_fill((self <= other).logical_or_(other.isnan()), 0) |
| result: other_t + (self_p <= other_p).logical_or_(other_p.isnan()) * (self_t - other_t) |
| |
| - name: amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor |
| self: scale_grad_by_count(restore_reduced_dims(grad, dim, keepdim), restore_reduced_dims(result, dim, keepdim) == self, dim) |
| result: amaxamin_jvp(self_p, self_t, result, dim, keepdim) |
| |
| - name: amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor |
| self: scale_grad_by_count(restore_reduced_dims(grad, dim, keepdim), restore_reduced_dims(result, dim, keepdim) == self, dim) |
| result: amaxamin_jvp(self_p, self_t, result, dim, keepdim) |
| |
| - name: mm(Tensor self, Tensor mat2) -> Tensor |
| self: mm_mat1_backward(grad, mat2, self.sym_sizes(), self.sym_strides(), self.layout(), 1) |
| mat2: mm_mat2_backward(grad, self, mat2.sym_sizes(), mat2.sym_strides(), mat2.layout(), 1) |
| result: at::mm(self_t, mat2_p) + at::mm(self_p, mat2_t) |
| |
| - name: mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) |
| self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim) |
| values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim) |
| |
| - name: mul.Tensor(Tensor self, Tensor other) -> Tensor |
| self: mul_tensor_backward(grad, other, self.scalar_type()) |
| other: mul_tensor_backward(grad, self, other.scalar_type()) |
| result: other_t * self_p + self_t * other_p |
| |
| - name: mul.Scalar(Tensor self, Scalar other) -> Tensor |
| self: mul_tensor_backward(grad, at::lift_fresh(at::scalar_to_tensor(other)), self.scalar_type()) |
| result: self_t * other |
| |
| - name: mv(Tensor self, Tensor vec) -> Tensor |
| self: grad.ger(vec.conj()) |
| vec: self.conj().t().mv(grad) |
| result: mv(self_t, vec_p) + mv(self_p, vec_t) |
| |
| - name: mvlgamma(Tensor self, int p) -> Tensor |
| self: mvlgamma_backward(grad, self, p) |
| result: auto_element_wise |
| |
| - name: nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor |
| self: grad * at::isfinite(self) |
| result: auto_element_wise |
| |
| - name: native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor) |
| input, weight, bias: "grad.defined() ? native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, eps, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()" |
| result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, eps) |
| |
| - name: _native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor) |
| input, weight, bias: "grad.defined() ? native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, eps, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()" |
| result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, eps) |
| |
| - name: _native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor) |
| input, weight, bias: "grad.defined() ? native_batch_norm_backward(grad, input, weight, Tensor(), Tensor(), result1, result2, training, eps, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()" |
| result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, Tensor(), Tensor(), result1, result2, training, eps) |
| |
| - name: native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor) |
| input, weight, grad_out: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, running_mean, running_var, train, eps, save_mean, save_invstd, grad_input_mask) |
| save_mean: not_implemented("native_batch_norm_backward save_mean") |
| save_invstd: not_implemented("native_batch_norm_backward save_invstd") |
| |
| - name: native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor) |
| input, weight, bias: "grad.defined() ? native_layer_norm_backward_symint(grad, input, normalized_shape, result1, result2, weight, bias, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()" |
| result0: layer_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, result1, result2, normalized_shape) |
| |
| - name: native_layer_norm_backward(Tensor grad_out, Tensor input, SymInt[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor) |
| input, weight, grad_out: layer_norm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, mean, rstd, normalized_shape, grad_input_mask) |
| bias: Tensor() |
| mean: not_implemented("native_layer_norm_backward mean") |
| rstd: not_implemented("native_layer_norm_backward rstd") |
| |
| - name: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor) |
| input, weight, bias: "GradMode::is_enabled() || grads[1].defined() || grads[2].defined() ? infinitely_differentiable_native_group_norm_backward(grads[0], grads[1], grads[2], input, result1, result2, weight, N, C, HxW, group, eps, grad_input_mask) : (grads[0].defined() ? native_group_norm_backward_symint(grads[0].device().is_xpu() ? grads[0] : grads[0].contiguous(grads[0].device().is_cpu() ? grads[0].suggest_memory_format() : c10::MemoryFormat::Contiguous), input.device().is_xpu() ? input : input.contiguous(input.device().is_cpu() ? input.suggest_memory_format() : c10::MemoryFormat::Contiguous), result1, result2, weight, N, C, HxW, group, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>())" |
| result0: group_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, result1, result2, group) |
| result1: group_norm_mean_jvp(input_t, result1, group) |
| result2: group_norm_invstd_jvp(input_p, input_t, result1, result2, group) |
| |
| - name: ne_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) |
| self: zeros_like(self) |
| result: self_t.zero_() |
| |
| - name: ne_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) |
| self: zeros_like(self) |
| other: zeros_like(other) |
| result: self_t.zero_() |
| |
| - name: neg(Tensor self) -> Tensor |
| self: grad.neg() |
| result: auto_element_wise |
| |
| - name: nextafter(Tensor self, Tensor other) -> Tensor |
| self: not_implemented("nextafter") |
| other: not_implemented("nextafter") |
| |
| - name: norm.Scalar(Tensor self, Scalar p=2) -> Tensor |
| self: norm_backward(grad, self, p, result) |
| result: norm_jvp(self_p, self_t, p, result) |
| |
| - name: norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> Tensor |
| self: norm_backward(grad, self, p, result, dim, keepdim) |
| result: norm_jvp(self_p, self_t, p, result, dim, keepdim) |
| |
| - name: norm.ScalarOpt_dtype(Tensor self, Scalar? p, *, ScalarType dtype) -> Tensor |
| self: norm_backward(grad, self.to(grad.scalar_type()), p, result) |
| result: norm_jvp(self_p, self_t, p, result) |
| |
| - name: norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor |
| self: norm_backward(grad, self.to(grad.scalar_type()), p, result, dim, keepdim) |
| result: norm_jvp(self_p, self_t, p, result, dim, keepdim) |
| |
| - name: linalg_vector_norm(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor |
| self: linalg_vector_norm_backward(grad, self, ord, result, dim, keepdim) |
| result: linalg_vector_norm_jvp(self_p, self_t, ord, result, dim, keepdim) |
| |
| - name: _pdist_forward(Tensor self, float p=2) -> Tensor |
| self: _pdist_backward(grad, self, p, result) |
| |
| - name: _pdist_backward(Tensor grad, Tensor self, float p, Tensor pdist) -> Tensor |
| grad: not_implemented("_pdist_backward") |
| self: not_implemented("_pdist_backward") |
| pdist: not_implemented("_pdist_backward") |
| |
| - name: _euclidean_dist(Tensor x1, Tensor x2) -> Tensor |
| x1, x2: _euclidean_dist_backward(grad, x1, x2, result) |
| |
| - name: _cdist_forward(Tensor x1, Tensor x2, float p, int? compute_mode) -> Tensor |
| x1: _cdist_backward(grad.contiguous(), x1, x2, p, result) |
| x2: _cdist_backward(grad.mT().contiguous(), x2, x1, p, result.mT().contiguous()) |
| |
| - name: _cdist_backward(Tensor grad, Tensor x1, Tensor x2, float p, Tensor cdist) -> Tensor |
| grad: not_implemented("_cdist_backward") |
| x1: not_implemented("_cdist_backward") |
| x2: not_implemented("_cdist_backward") |
| cdist: not_implemented("_cdist_backward") |
| |
| - name: normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!) |
| self: zeros_like(grad) |
| result: self_t.zero_() |
| |
| - name: normal.Tensor_float(Tensor mean, float std=1, *, Generator? generator=None) -> Tensor |
| mean: at::zeros_symint(mean.sym_sizes(), grad.options()) |
| result: auto_element_wise |
| |
| - name: normal.float_Tensor(float mean, Tensor std, *, Generator? generator=None) -> Tensor |
| std: at::zeros_symint(std.sym_sizes(), grad.options()) |
| result: auto_element_wise |
| |
| - name: normal.Tensor_Tensor(Tensor mean, Tensor std, *, Generator? generator=None) -> Tensor |
| mean: at::zeros_symint(mean.sym_sizes(), grad.options()) |
| std: at::zeros_symint(std.sym_sizes(), grad.options()) |
| result: zeros_like(mean_t) |
| |
| - name: linalg_householder_product(Tensor input, Tensor tau) -> Tensor |
| input, tau: householder_product_backward(grad, result, input, tau) |
| result: householder_product_jvp(input_t, tau_t, result, input_p, tau_p) |
| |
| - name: ormqr(Tensor self, Tensor input2, Tensor input3, bool left=True, bool transpose=False) -> Tensor |
| self, input2, input3: ormqr_backward(grad, result, self, input2, input3, left, transpose, grad_input_mask) |
| |
| - name: permute(Tensor(a) self, int[] dims) -> Tensor(a) |
| self: permute_backwards(grad, dims) |
| result: auto_linear |
| |
| - name: poisson(Tensor self, Generator? generator=None) -> Tensor |
| self: zeros_like(self) |
| result: auto_element_wise |
| |
| - name: pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor |
| self: pow_backward(grad, self, exponent) |
| result: auto_element_wise |
| |
| - name: pow.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor |
| self: pow_backward_self(grad, self, exponent) |
| exponent: pow_backward_exponent(grad, self, exponent, result) |
| result: (pow_backward_self(self_t.conj(), self_p, exponent_p) + pow_backward_exponent(exponent_t.conj(), self_p, exponent_p, result)).conj() |
| |
| - name: pow.Scalar(Scalar self, Tensor exponent) -> Tensor |
| exponent: pow_backward_exponent(grad, self, exponent, result) |
| result: auto_element_wise |
| |
| - name: prod(Tensor self, *, ScalarType? dtype=None) -> Tensor |
| self: prod_backward(grad, self.to(grad.scalar_type()), result) |
| result: (prod_backward(at::ones({}, result.options()).expand_as(result), self_p.to(result.scalar_type()), result) * self_t.conj()).sum().conj() |
| |
| - name: prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor |
| self: prod_backward(grad, self.to(grad.scalar_type()), result, dim, keepdim) |
| result: (prod_backward(at::ones({}, result.options()).expand_as(result), self_p.to(result.scalar_type()), result, dim, keepdim) * self_t.conj()).sum(dim, keepdim).conj() |
| |
| - name: put(Tensor self, Tensor index, Tensor source, bool accumulate=False) -> Tensor |
| self: "accumulate ? grad : grad.put(index, zeros_like(source), false)" |
| index: non_differentiable |
| source: grad.take(index).reshape_as(source) |
| result: self_t.put(index, source_t, accumulate) |
| |
| - name: linalg_qr(Tensor A, str mode='reduced') -> (Tensor Q, Tensor R) |
| A: linalg_qr_backward(grad_Q, grad_R, Q, R, mode) |
| Q, R: linalg_qr_jvp(A_t, Q, R, mode) |
| |
| - name: rad2deg(Tensor self) -> Tensor |
| self: rad2deg_backward(grad) |
| result: auto_element_wise |
| |
| - name: random_.from(Tensor(a!) self, int from, int? to, *, Generator? generator=None) -> Tensor(a!) |
| self: zeros_like(grad) |
| result: self_t.zero_() |
| |
| - name: random_.to(Tensor(a!) self, int to, *, Generator? generator=None) -> Tensor(a!) |
| self: zeros_like(grad) |
| result: self_t.zero_() |
| |
| - name: random_(Tensor(a!) self, *, Generator? generator=None) -> Tensor(a!) |
| self: zeros_like(grad) |
| result: self_t.zero_() |
| |
| - name: reciprocal(Tensor self) -> Tensor |
| self: -grad * (result * result).conj() |
| result: auto_element_wise |
| |
| - name: remainder.Scalar(Tensor self, Scalar other) -> Tensor |
| self: grad |
| result: auto_element_wise |
| |
| - name: remainder.Tensor(Tensor self, Tensor other) -> Tensor |
| self: grad |
| other: -grad * self.div(other, /*rounding_mode=*/"floor") |
| result: self_t - other_t * self_p.div(other_p, /*rounding_mode=*/"floor") |
| |
| - name: renorm(Tensor self, Scalar p, int dim, Scalar maxnorm) -> Tensor |
| self: renorm_backward(grad, self, p, dim, maxnorm) |
| |
| - name: repeat(Tensor self, SymInt[] repeats) -> Tensor |
| self: repeat_backward(grad, repeats, self.sym_sizes()) |
| result: auto_linear |
| |
| - name: special_entr(Tensor self) -> Tensor |
| self: grad * (-(1 + self.log())) |
| result: auto_element_wise |
| |
| - name: special_ndtri(Tensor self) -> Tensor |
| self: grad * std::sqrt(2 * M_PI) * (result.square() / 2).exp() |
| result: auto_element_wise |
| |
| - name: special_log_ndtr(Tensor self) -> Tensor |
| self: grad / std::sqrt(2 * M_PI) * (result + self.pow(2) / 2).neg().exp() |
| result: auto_element_wise |
| |
| # [Note: Sometimes view derivatives] |
| # The following situation applies to other operations as well. |
| # TODO: This note is currently only referenced by to_dense. Make it |
| # more generic if it ends up being referenced from more than one place. |
| # |
| # DO NOT define a backward for reshape! |
| # reshape is special in that it sometimes returns a view, and sometimes not. |
| # Defining a backward will make codegen spit out the forward call as |
| # as_variable(baseType->reshape(self)), |
| # making it impossible (hard) to detect when it is actually a view. |
| # - name: reshape(Tensor self, IntArrayRef shape) |
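| # |
| # A quick illustration (our annotation, standard torch semantics) of "sometimes returns a view, |
| # and sometimes not": |
| # |
| #   import torch |
| #   a = torch.arange(6.).reshape(2, 3) |
| #   v = a.reshape(6)      # contiguous input: shares storage with `a` (a view) |
| #   w = a.t().reshape(6)  # non-contiguous input: must copy |
| #   print(v.data_ptr() == a.data_ptr())  # True |
| #   print(w.data_ptr() == a.data_ptr())  # False |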
| |
| - name: _reshape_alias(Tensor(a) self, SymInt[] size, SymInt[] stride) -> Tensor(a) |
| self: grad.reshape_symint(self.sym_sizes()) |
| result: auto_linear |
| |
| - name: round(Tensor self) -> Tensor |
| self: zeros_like(grad) |
| result: auto_element_wise |
| |
| - name: round.decimals(Tensor self, *, int decimals) -> Tensor |
| self: zeros_like(grad) |
| result: auto_element_wise |
| |
| - name: rsqrt(Tensor self) -> Tensor |
| self: -0.5 * grad * result.pow(3).conj() |
| result: auto_element_wise |
| |
| - name: scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor |
| self: grad.scatter(dim, index, 0) |
| index: non_differentiable |
| src: grad.gather(dim, index) |
| result: self_t.scatter(dim, index, src_t) |
| |
| - name: scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor |
| self: grad.scatter(dim, index, 0) |
| index: non_differentiable |
| result: self_t.scatter(dim, index, 0) |
| |
| - name: scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor |
| self: grad |
| index: non_differentiable |
| src: grad.gather(dim, index) |
| result: scatter_add(self_t, dim, index, src_t) |
| |
| - name: select.int(Tensor(a) self, int dim, SymInt index) -> Tensor(a) |
| dispatch: |
| Default: |
| self: select_backward_symint(grad, self.sym_sizes(), dim, index) |
| result: auto_linear |
| AutogradNestedTensor: |
| self: _nested_select_backward_symint(grad, self, dim, index) |
| |
| - name: select_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt index) -> Tensor |
| grad_output: grad.select_symint(dim, index) |
| result: auto_linear |
| |
| - name: sigmoid(Tensor self) -> Tensor |
| self: sigmoid_backward(grad, result) |
| result: auto_element_wise |
| |
| - name: logit(Tensor self, float? eps=None) -> Tensor |
| self: "GradMode::is_enabled() ? infinitely_differentiable_logit_backward(grad, self, eps) : logit_backward(grad, self, eps)" |
| result: auto_element_wise |
| |
| - name: sign(Tensor self) -> Tensor |
| self: zeros_like(grad) |
| result: auto_element_wise |
| |
| - name: sgn(Tensor self) -> Tensor |
| self: sgn_backward(self, grad, result) |
| # Cannot use auto_element_wise here because the Jacobian is *not* Hermitian (in fact, it is symmetric). |
| # The function is not holomorphic, so there is no reason for its Jacobian to be Hermitian. |
| # auto_element_wise has a name that is a bit misleading in the complex case. |
| # (A short derivation sketch follows this entry.) |
| result: sgn_backward(self_p, self_t, result) |
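| |
| # Derivation sketch (our annotation, Wirtinger calculus): for z != 0, sgn(z) = z / abs(z), so for |
| # a tangent t, |
| #   d(sgn)(z)[t] = (t - sgn(z)^2 * conj(t)) / (2 * abs(z)) |
| # The conj(t) term makes the map t -> d(sgn)(z)[t] only real-linear, not complex-linear, which is |
| # why the JVP is written out with sgn_backward(self_p, self_t, result) instead of being derived |
| # mechanically via auto_element_wise. |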
| |
| - name: sin(Tensor self) -> Tensor |
| self: grad * self.cos().conj() |
| result: auto_element_wise |
| |
| - name: sinc(Tensor self) -> Tensor |
| self: sinc_backward(grad, self) |
| result: auto_element_wise |
| |
| - name: sinh(Tensor self) -> Tensor |
| self: grad * self.cosh().conj() |
| result: auto_element_wise |
| |
| - name: slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a) |
| self: slice_backward_wrapper(grad, self.sym_sizes(), dim, start, end, step) |
| result: auto_linear |
| |
| - name: slice_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt start, SymInt end, SymInt step) -> Tensor |
| grad_output: grad.slice_symint(dim, start, end, step) |
| result: auto_linear |
| |
| - name: slice_scatter(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor |
| self: slice_scatter_symint(grad, zeros_like(src), dim, start, end, step) |
| src: grad.slice_symint(dim, start, end, step) |
| result: auto_linear |
| |
| - name: select_scatter(Tensor self, Tensor src, int dim, SymInt index) -> Tensor |
| self: select_scatter_symint(grad, zeros_like(src), dim, index) |
| src: grad.select_symint(dim, index) |
| result: auto_linear |
| |
| - name: diagonal_scatter(Tensor self, Tensor src, int offset=0, int dim1=0, int dim2=1) -> Tensor |
| self: diagonal_scatter(grad, zeros_like(src), offset, dim1, dim2) |
| src: grad.diagonal(offset, dim1, dim2) |
| result: auto_linear |
| |
| - name: as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor |
| self: as_strided_scatter_backward(grad, TensorGeometry(self), TensorGeometry(src), size, stride, storage_offset) |
| # See Note [as_strided_scatter backward support] |
| src: grad.contiguous().as_strided_symint(size, stride, storage_offset) |
| result: auto_linear |
| |
| - name: _linalg_solve_ex(Tensor A, Tensor B, *, bool left=True, bool check_errors=False) -> (Tensor result, Tensor LU, Tensor pivots, Tensor info) |
| A, B: linalg_solve_backward(grad, result, A, LU, pivots, left, grad_input_mask[1]) |
| result: "linalg_solve_jvp(A_t, B_t, result, LU, pivots, left, A_p.is_contiguous() && !A_p.is_complex())" |
| output_differentiability: [True, False, False, False] # LU is an auxiliary tensor not exposed to the user |
| |
| - name: sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices) |
| self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), true) |
| output_differentiability: [True, False] |
| values: gather_with_keepdimed_indices(self_t, dim, indices, true) |
| |
| - name: sort.stable(Tensor self, *, bool? stable, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices) |
| self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), true) |
| output_differentiability: [True, False] |
| values: gather_with_keepdimed_indices(self_t, dim, indices, true) |
| |
| - name: split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[] |
| self: split_backward(grads, split_size, dim, self.sym_sizes(), self.options()) |
| result: auto_linear |
| |
| - name: unsafe_split.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[] |
| self: split_backward(grads, split_size, dim, self.sym_sizes(), self.options()) |
| result: auto_linear |
| |
| - name: split_with_sizes(Tensor(a -> *) self, SymInt[] split_sizes, int dim=0) -> Tensor(a)[] |
| self: split_with_sizes_backward(grads, split_sizes, dim, self.sym_sizes(), self.options()) |
| result: auto_linear |
| |
| - name: unsafe_split_with_sizes(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[] |
| self: split_with_sizes_backward(grads, split_sizes, dim, self.sym_sizes(), self.options()) |
| result: auto_linear |
| |
| - name: sqrt(Tensor self) -> Tensor |
| self: grad / (2 * result.conj()) |
| result: auto_element_wise |
| |
| - name: squeeze(Tensor(a) self) -> Tensor(a) |
| self: unsqueeze_to(grad, self.sym_sizes()) |
| result: auto_linear |
| |
| - name: squeeze.dim(Tensor(a) self, int dim) -> Tensor(a) |
| dispatch: |
| Default: |
| self: unsqueeze_to(grad, dim, self.sym_sizes()) |
| result: auto_linear |
| AutogradNestedTensor: |
| self: grad.unsqueeze(dim) |
| |
| - name: squeeze.dims(Tensor(a) self, int[] dim) -> Tensor(a) |
| dispatch: |
| Default: |
| self: unsqueeze_to(grad, dim, self.sym_sizes()) |
| result: auto_linear |
| AutogradNestedTensor: |
| self: unsqueeze_multiple(grad, dim, self.dim()) |
| |
| - name: squeeze_(Tensor(a!) self) -> Tensor(a!) |
| self: unsqueeze_to(grad, self.sym_sizes()) |
| result: auto_linear |
| |
| - name: squeeze_.dim(Tensor(a!) self, int dim) -> Tensor(a!) |
| self: unsqueeze_to(grad, dim, self.sym_sizes()) |
| result: auto_linear |
| |
| - name: squeeze_.dims(Tensor(a!) self, int[] dim) -> Tensor(a!) |
| self: unsqueeze_to(grad, dim, self.sym_sizes()) |
| result: auto_linear |
| |
| - name: std.correction(Tensor self, int[1]? dim=None, *, int? correction=None, bool keepdim=False) -> Tensor |
| self: std_backward(result, grad, self, dim, correction, keepdim) |
| # pointwise (variance) + sum + sqrt |
| result: (at::real(var_backward(self_t.conj(), self_p, dim, correction, true).sum(dim.value_or(IntArrayRef({})), keepdim)) / (2. * result)).masked_fill_(result == 0, 0) |
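| |
| # Chain-rule sketch for the formula above (our annotation): with std = sqrt(var), |
| #   d(std) = d(var) / (2 * std) |
| # where d(var) is the pointwise var JVP summed over `dim` (the "pointwise + sum" part), and |
| # masked_fill_ takes the subgradient 0 where std == 0. |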
| |
| - name: std_mean.correction(Tensor self, int[1]? dim=None, *, int? correction=None, bool keepdim=False) -> (Tensor, Tensor) |
| self: std_mean_backward(grads[0], grads[1], self, result0, dim, correction, keepdim) |
| result0: (at::real(var_backward(self_t.conj(), self_p, dim, correction, true).sum(dim.value_or(IntArrayRef({})), keepdim)) / (2. * result0)).masked_fill_(result0 == 0, 0) |
| # linear |
| result1: mean(self_t, dim.value_or(IntArrayRef({})), keepdim) |
| |
| - name: sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor |
| self: handle_r_to_c(self.scalar_type(), grad) |
| other: handle_r_to_c(other.scalar_type(), maybe_multiply(-grad, alpha.conj())) |
| result: self_t - maybe_multiply(other_t, alpha) |
| |
| - name: sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor |
| self: handle_r_to_c(self.scalar_type(), grad) |
| result: auto_element_wise |
| |
| - name: rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor |
| self: handle_r_to_c(self.scalar_type(), maybe_multiply(-grad, alpha.conj())) |
| other: handle_r_to_c(other.scalar_type(), grad) |
| result: -maybe_multiply(self_t, alpha) + other_t |
| |
| - name: rsub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor |
| self: handle_r_to_c(self.scalar_type(), maybe_multiply(-grad, alpha.conj())) |
| result: auto_element_wise |
| |
| - name: sum(Tensor self, *, ScalarType? dtype=None) -> Tensor |
| self: grad.expand_symint(self.sym_sizes()) |
| result: auto_linear |
| |
| - name: sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor |
| dispatch: |
| Default: |
| self: sum_backward(grad, self.sym_sizes(), dim, keepdim) |
| result: auto_linear |
| AutogradNestedTensor: |
| # TODO: replace this function once semantics for nested tensor expand have been settled on |
| self: _nested_sum_backward(grad, self, dim, keepdim) |
| |
| - name: nansum(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor |
| self: nansum_backward(grad.to(self.scalar_type()), self, dim, keepdim) |
| result: at::where(self_p.isnan(), 0, self_t).sum(dim, keepdim, dtype) |
| |
| # We never call _linalg_svd with compute_uv=False in an autograd context, so we don't even consider it here |
| - name: _linalg_svd(Tensor A, bool full_matrices=False, bool compute_uv=True, *, str? driver=None) -> (Tensor U, Tensor S, Tensor Vh) |
| A: "svd_backward(full_matrices && grad_U.defined() ? grad_U.narrow_symint(-1, 0, S.sym_size(-1)) : grad_U, |
| grad_S, |
| full_matrices && grad_Vh.defined() ? grad_Vh.narrow_symint(-2, 0, S.sym_size(-1)) : grad_Vh, |
| full_matrices ? U.narrow_symint(-1, 0, S.sym_size(-1)) : U, |
| S, |
| full_matrices ? Vh.narrow_symint(-2, 0, S.sym_size(-1)) : Vh)" |
| U, S, Vh: linalg_svd_jvp(A_t, U, S, Vh, full_matrices) |
| |
| - name: symeig(Tensor self, bool eigenvectors=False, bool upper=True) -> (Tensor eigenvalues, Tensor eigenvectors) |
| self: linalg_eig_backward(grads[0], grads[1], eigenvalues, eigenvectors_return, /*is_hermitian=*/true, /*symeig_eigenvector=*/eigenvectors) |
| |
| - name: _linalg_eigh(Tensor A, str UPLO="L", bool compute_v=True) -> (Tensor eigenvalues, Tensor eigenvectors) |
| A: linalg_eig_backward(grads[0], grads[1], eigenvalues, eigenvectors, /*is_hermitian=*/true) |
| eigenvalues, eigenvectors: linalg_eig_jvp(A_t, eigenvalues, eigenvectors, /*is_hermitian=*/true) |
| |
| - name: linalg_eig(Tensor self) -> (Tensor eigenvalues, Tensor eigenvectors) |
| self: handle_r_to_c(self.scalar_type(), linalg_eig_backward(grads[0], grads[1], eigenvalues, eigenvectors, /*is_hermitian=*/false)) |
| eigenvalues, eigenvectors: linalg_eig_jvp(self_t, eigenvalues, eigenvectors, /*is_hermitian=*/false) |
| |
| - name: t(Tensor(a) self) -> Tensor(a) |
| self: grad.t() |
| result: auto_linear |
| |
| - name: t_(Tensor(a!) self) -> Tensor(a!) |
| self: grad.t() |
| result: auto_linear |
| |
| - name: one_hot(Tensor self, int num_classes=-1) -> Tensor |
| self: non_differentiable |
| |
| - name: flip(Tensor self, int[] dims) -> Tensor |
| self: grad.flip(dims) |
| result: auto_linear |
| |
| - name: roll(Tensor self, int[1] shifts, int[1] dims=[]) -> Tensor |
| self: grad.roll(fmap(reverse_list(shifts), [](int64_t i){return -i;}), reverse_list(dims)) |
| result: auto_linear |
| |
| - name: rot90(Tensor self, int k=1, int[] dims=[0,1]) -> Tensor |
| self: grad.rot90(-k, dims) |
| result: auto_linear |
| |
| - name: take(Tensor self, Tensor index) -> Tensor |
| self: take_backward(grad, self, index) |
| index: non_differentiable |
| result: auto_linear |
| |
| - name: tan(Tensor self) -> Tensor |
| self: grad * (1 + result.pow(2)).conj() |
| result: auto_element_wise |
| |
| - name: tanh(Tensor self) -> Tensor |
| self: tanh_backward(grad, result) |
| result: auto_element_wise |
| |
| - name: topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices) |
| self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), true) |
| output_differentiability: [True, False] |
| values: gather(self_t, dim, indices) |
| |
| - name: trace(Tensor self) -> Tensor |
| self: trace_backward_symint(grad, self.sym_sizes()) |
| result: auto_linear |
| |
| - name: transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a) |
| self: grad.transpose(dim0, dim1) |
| result: auto_linear |
| |
| - name: transpose_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!) |
| self: grad.transpose(dim0, dim1) |
| result: auto_linear |
| |
| - name: triangular_solve(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False) -> (Tensor solution, Tensor cloned_coefficient) |
| self, A: triangular_solve_backward(grad_solution, grad_cloned_coefficient, self, A, solution, upper, transpose, unitriangular, grad_input_mask) |
| solution: triangular_solve_jvp(solution, A_p, A_t, self_t, upper, transpose, unitriangular) |
| cloned_coefficient: A_t |
| |
| - name: linalg_solve_triangular(Tensor self, Tensor B, *, bool upper, bool left=True, bool unitriangular=False) -> Tensor |
| self, B: linalg_solve_triangular_backward(grad, self, result, upper, left, unitriangular, grad_input_mask) |
| result: linalg_solve_triangular_forward_AD(self_t, B_t, self_p, result, upper, left, unitriangular) |
| |
| - name: tril(Tensor self, int diagonal=0) -> Tensor |
| self: grad.tril(diagonal) |
| result: auto_linear |
| |
| - name: triu(Tensor self, int diagonal=0) -> Tensor |
| self: grad.triu(diagonal) |
| result: auto_linear |
| |
| - name: trunc(Tensor self) -> Tensor |
| self: zeros_like(grad) |
| result: auto_element_wise |
| |
| # DO NOT define a backward for to_dense |
| # See [Note: Sometimes view derivatives] |
| # - name: to_dense(Tensor self, ScalarType? dtype=None) -> Tensor |
| # |
| - name: _to_dense(Tensor self, ScalarType? dtype=None) -> Tensor |
| self: to_dense_backward(grad, self) |
| |
| - name: to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor |
| self: to_sparse_backward(grad, self.layout(), self.sym_blocksize()) |
| |
| - name: to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor |
| self: to_sparse_backward(grad, self.layout(), self.sym_blocksize()) |
| |
| - name: to_sparse_csr(Tensor self, int? dense_dim=None) -> Tensor |
| self: to_sparse_backward(grad, self.layout(), self.sym_blocksize()) |
| |
| - name: to_sparse_csc(Tensor self, int? dense_dim=None) -> Tensor |
| self: to_sparse_backward(grad, self.layout(), self.sym_blocksize()) |
| |
| - name: to_sparse_bsr(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor |
| self: to_sparse_backward(grad, self.layout(), self.sym_blocksize()) |
| |
| - name: to_sparse_bsc(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor |
| self: to_sparse_backward(grad, self.layout(), self.sym_blocksize()) |
| |
| - name: to_mkldnn(Tensor self, ScalarType? dtype=None) -> Tensor |
| self: to_mkldnn_backward(grad, self) |
| |
| - name: unfold(Tensor(a) self, int dimension, int size, int step) -> Tensor(a) |
| self: unfold_backward_symint(grad, self.sym_sizes(), dimension, size, step) |
| result: auto_linear |
| |
| - name: unfold_backward(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step) -> Tensor |
| grad_in: grad.unfold(dim, size, step) |
| result: auto_linear |
| |
| - name: uniform_(Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None) -> Tensor(a!) |
| self: zeros_like(grad) |
| result: self_t.zero_() |
| |
| - name: _unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor) |
| output_differentiability: [True, False] |
| self: not_implemented("_unique") |
| |
| - name: unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) |
| output_differentiability: [True, False, False] |
| self: not_implemented("unique_dim") |
| |
| - name: unique_consecutive(Tensor self, bool return_inverse=False, bool return_counts=False, int? dim=None) -> (Tensor, Tensor, Tensor) |
| output_differentiability: [True, False, False] |
| self: not_implemented("unique_consecutive") |
| |
| - name: unique_dim_consecutive(Tensor self, int dim, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) |
| output_differentiability: [True, False, False] |
| self: not_implemented("unique_dim_consecutive") |
| |
| - name: _unique2(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) |
| output_differentiability: [True, False, False] |
| self: not_implemented("_unique2") |
| |
| - name: _unsafe_view(Tensor self, SymInt[] size) -> Tensor |
| self: grad.reshape_symint(self.sym_sizes()) |
| result: auto_linear |
| |
| - name: lift(Tensor self) -> Tensor |
| self: grad |
| result: auto_linear |
| |
| - name: lift_fresh(Tensor(a) self) -> Tensor(a) |
| self: grad |
| result: auto_linear |
| |
| - name: unsqueeze(Tensor(a) self, int dim) -> Tensor(a) |
| self: grad.squeeze(dim) |
| result: auto_linear |
| |
| - name: unsqueeze_(Tensor(a!) self, int dim) -> Tensor(a!) |
| self: grad.squeeze(dim) |
| result: auto_linear |
| |
| - name: var.correction(Tensor self, int[1]? dim=None, *, int? correction=None, bool keepdim=False) -> Tensor |
| self: var_backward(grad, self, dim, correction, keepdim) |
| # pointwise + sum |
| result: at::real(var_backward(self_t.conj(), self_p, dim, correction, true).sum(dim.value_or(IntArrayRef({})), keepdim)) |
| |
| - name: var_mean.correction(Tensor self, int[1]? dim=None, *, int? correction=None, bool keepdim=False) -> (Tensor, Tensor) |
| self: var_mean_backward(grads[0], grads[1], self, dim, correction, keepdim) |
| result0: at::real(var_backward(self_t.conj(), self_p, dim, correction, true).sum(dim.value_or(IntArrayRef({})), keepdim)) |
| # linear |
| result1: mean(self_t, dim.value_or(IntArrayRef({})), keepdim) |
| |
| - name: view(Tensor(a) self, SymInt[] size) -> Tensor(a) |
| dispatch: |
| Default: |
| self: grad.reshape_symint(self.sym_sizes()) |
| result: auto_linear |
| AutogradNestedTensor: |
| self: grad.reshape_as(self) |
| result: auto_linear |
| |
| - name: view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a) |
| output_differentiability: [False] |
| |
| - name: view_as_real(Tensor(a) self) -> Tensor(a) |
| self: at::view_as_complex(grad.contiguous()) # gx0 + 1j * gx1 |
| result: at::view_as_real(self_t) |
| |
| - name: view_as_complex(Tensor(a) self) -> Tensor(a) |
| self: at::view_as_real(grad.contiguous().resolve_conj()) # [gx, gy] |
| result: at::view_as_complex(self_t) |
| |
| - name: where.self(Tensor condition, Tensor self, Tensor other) -> Tensor |
| condition: non_differentiable |
| self: where(condition, grad, 0) |
| other: where(condition, 0, grad) |
| result: where(condition, self_t, other_t) |
| |
| # _weight_norm_interface_backward does not have an explicitly defined derivative, so if we happen |
| # to be running backward with create_graph=True, fall back to a backward function that uses |
| # differentiable ops. (See the usage sketch after this entry.) |
| - name: _weight_norm_interface(Tensor v, Tensor g, int dim=0) -> (Tensor, Tensor) |
| v, g: "grad.defined() ? (GradMode::is_enabled() ? _weight_norm_differentiable_backward(grad.contiguous(), v, g, result1, dim) : _weight_norm_interface_backward(grad.contiguous(), v, g, result1, dim)) : std::tuple<Tensor, Tensor>()" |
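| |
| # Usage sketch (our annotation; torch._weight_norm is the public entry point, and on CPU a |
| # decomposed, always-differentiable path may be taken instead of _weight_norm_interface): |
| # |
| #   import torch |
| #   v = torch.randn(3, 4, requires_grad=True) |
| #   g = torch.randn(3, 1, requires_grad=True) |
| #   w = torch._weight_norm(v, g, 0) |
| #   (grad_v,) = torch.autograd.grad(w.sum(), v, create_graph=True) |
| #   # create_graph=True keeps grad mode enabled inside the backward, so the entry above picks |
| #   # the differentiable _weight_norm_differentiable_backward branch. |
| #   grad_v.sum().backward() |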
| |
| - name: zero_(Tensor(a!) self) -> Tensor(a!) |
| self: zeros_like(grad) |
| result: auto_linear |
| |
| - name: sparse_mask(Tensor self, Tensor mask) -> Tensor |
| self: grad.to_dense().sparse_mask(mask).to_dense() |
| mask: non_differentiable |
| |
| - name: _sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor |
| values: sparse_constructor_values_backward(grad, indices) |
| |
| - name: _sparse_sum.dim(Tensor self, int[1] dim) -> Tensor |
| self: at::_sparse_sum_backward(grad, self, dim) |
| |
| - name: _standard_gamma(Tensor self, Generator? generator=None) -> Tensor |
| self: grad * _standard_gamma_grad(self, result) |
| |
| - name: _standard_gamma_grad(Tensor self, Tensor output) -> Tensor |
| self: not_implemented("_standard_gamma_grad") |
| |
| - name: values(Tensor(a) self) -> Tensor(a) |
| dispatch: |
| Default: |
| self: at::_sparse_coo_tensor_unsafe_symint(self.indices(), grad, self.sym_sizes())._coalesced_(true) |
| AutogradNestedTensor: |
| self: at::_nested_view_from_buffer(grad.contiguous(), self._nested_tensor_size(), self._nested_tensor_strides(), self._nested_tensor_offsets()) |
| |
| # Why is _values() not differentiable? |
| # See NOTE [ Sparse: autograd and API ] |
| - name: _values(Tensor(a) self) -> Tensor(a) |
| output_differentiability: [False] |
| |
| # NN |
| - name: _trilinear(Tensor i1, Tensor i2, Tensor i3, int[] expand1, int[] expand2, int[] expand3, int[] sumdim, int unroll_dim=1) -> Tensor |
| i1, i2, i3: _trilinear_backward(grad, i1, i2, i3, expand1, expand2, expand3, sumdim, grad_input_mask) |
| result: "_trilinear(i1_t, i2_p, i3_p, expand1, expand2, expand3, sumdim, unroll_dim) + |
| _trilinear(i1_p, i2_t, i3_p, expand1, expand2, expand3, sumdim, unroll_dim) + |
| _trilinear(i1_p, i2_p, i3_t, expand1, expand2, expand3, sumdim, unroll_dim)" |
| |
| - name: constant_pad_nd(Tensor self, SymInt[] pad, Scalar value=0) -> Tensor |
| self: constant_pad_nd_backward(grad, pad) |
| result: constant_pad_nd_symint(self_t, pad, 0) |
| |
| - name: binary_cross_entropy(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor |
| self: binary_cross_entropy_backward(grad, self, target, weight, reduction) |
| target: binary_cross_entropy_target_backward(grad, self, target, weight, reduction) |
| result: "apply_loss_reduction( |
| binary_cross_entropy_backward(self_t, self_p, target_p, weight, at::Reduction::None) |
| + binary_cross_entropy_target_backward(target_t, self_p, target_p, weight, at::Reduction::None), |
| reduction)" |
| |
| - name: binary_cross_entropy_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor |
| self: binary_cross_entropy_double_backward(grad_output, grad, self, target, weight, reduction) |
| target: binary_cross_entropy_double_backward_target(grad, grad_output, self, target, weight, reduction) |
| grad_output: binary_cross_entropy_double_backward_grad_output(grad, self, target, weight, reduction) |
| result: " binary_cross_entropy_double_backward(grad_output_p, self_t, self_p, target_p, weight, reduction) |
| + binary_cross_entropy_double_backward_target(target_t, grad_output_p, self_p, target_p, weight, reduction) |
| + binary_cross_entropy_double_backward_grad_output(grad_output_t, self_p, target_p, weight, reduction)" |
| |
| - name: binary_cross_entropy_with_logits(Tensor self, Tensor target, Tensor? weight=None, Tensor? pos_weight=None, int reduction=Mean) -> Tensor |
| self: binary_cross_entropy_with_logits_backward(grad, self, target, weight, pos_weight, reduction) |
| target: binary_cross_entropy_with_logits_target_backward(grad, self, target, weight, pos_weight, reduction) |
| result: "apply_loss_reduction( |
| binary_cross_entropy_with_logits_backward(self_t, self_p, target_p, weight, pos_weight, at::Reduction::None) |
| + binary_cross_entropy_with_logits_target_backward(target_t, self_p, target_p, weight, pos_weight, at::Reduction::None), |
| reduction)" |
| |
| - name: embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor |
| indices: non_differentiable |
| weight: embedding_backward_symint(grad, indices, weight.sym_size(0), padding_idx, scale_grad_by_freq, sparse) |
| result: auto_linear |
| |
| - name: embedding_dense_backward(Tensor grad_output, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq) -> Tensor |
| grad_output: embedding_dense_double_backward_symint(grad, indices, padding_idx) |
| indices: non_differentiable |
| result: auto_linear |
| |
| - name: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) -> (Tensor, Tensor, Tensor, Tensor) |
| indices: non_differentiable |
| offsets: non_differentiable |
| weight: _embedding_bag_backward_symint(grad, indices, offsets, result1, result2, result3, weight.sym_size(0), scale_grad_by_freq, mode, sparse, per_sample_weights, padding_idx) |
| per_sample_weights: _embedding_bag_per_sample_weights_backward(grad, weight, indices, offsets, result1, mode, padding_idx) |
| |
| - name: _embedding_bag_dense_backward(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor |
| indices: non_differentiable |
| offset2bag: non_differentiable |
| bag_size: non_differentiable |
| maximum_indices: non_differentiable |
| |
| - name: embedding_renorm_(Tensor(a!) self, Tensor indices, float max_norm, float norm_type) -> Tensor(a!) |
| indices: non_differentiable |
| self: not_implemented("embedding_renorm") |
| |
| - name: mse_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor |
| self: mse_loss_backward(grad, self, target, reduction) |
| target: mse_loss_backward(grad, target, self, reduction) |
| result: apply_loss_reduction(mse_loss_backward(self_t.conj(), self_p, target_p, at::Reduction::None).conj() + mse_loss_backward(target_t.conj(), target_p, self_p, at::Reduction::None).conj(), reduction) |
| |
| - name: multi_margin_loss(Tensor self, Tensor target, Scalar p=1, Scalar margin=1, Tensor? weight=None, int reduction=Mean) -> Tensor |
| self: multi_margin_loss_backward(grad, self, target, p, margin, weight, reduction) |
| target: non_differentiable |
| |
| - name: multilabel_margin_loss_forward(Tensor self, Tensor target, int reduction) -> (Tensor output, Tensor is_target) |
| self: multilabel_margin_loss_backward(grad, self, target, reduction, is_target) |
| target: non_differentiable |
| |
| - name: nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight) |
| self: nll_loss_backward_symint(grad, self, target, weight, reduction, ignore_index, total_weight) |
| target: non_differentiable |
| output: std::get<0>(nll_loss_forward_symint(self_t, target, weight, reduction, ignore_index)) |
| |
| - name: nll_loss2d_forward(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight) |
| self: nll_loss2d_backward_symint(grad, self, target, weight, reduction, ignore_index, total_weight) |
| target: non_differentiable |
| output: std::get<0>(nll_loss2d_forward_symint(self_t, target, weight, reduction, ignore_index)) |
| |
| - name: smooth_l1_loss(Tensor self, Tensor target, int reduction=Mean, float beta=1.0) -> Tensor |
| self: smooth_l1_loss_backward(grad, self, target, reduction, beta) |
| target: smooth_l1_loss_backward(grad, target, self, reduction, beta) |
| result: apply_loss_reduction(smooth_l1_loss_backward(self_t.conj(), self_p, target_p, at::Reduction::None, beta).conj() + smooth_l1_loss_backward(target_t.conj(), target_p, self_p, at::Reduction::None, beta).conj(), reduction) |
| |
| - name: huber_loss(Tensor self, Tensor target, int reduction=Mean, float delta=1.0) -> Tensor |
| self: huber_loss_backward(grad, self, target, reduction, delta) |
| target: huber_loss_backward(grad, target, self, reduction, delta) |
| result: apply_loss_reduction(huber_loss_backward(self_t.conj(), self_p, target_p, at::Reduction::None, delta).conj() + huber_loss_backward(target_t.conj(), target_p, self_p, at::Reduction::None, delta).conj(), reduction) |
| |
| - name: soft_margin_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor |
| self: soft_margin_loss_backward(grad, self, target, reduction) |
| result: apply_loss_reduction(soft_margin_loss_backward(self_t.conj(), self_p, target, at::Reduction::None).conj(), reduction) |
| |
| - name: relu(Tensor self) -> Tensor |
| self: threshold_backward(grad, result, 0) |
| result: auto_element_wise |
| |
| - name: silu(Tensor self) -> Tensor |
| self: "GradMode::is_enabled() ? infinitely_differentiable_silu_backward(grad, self) : silu_backward(grad, self)" |
| result: auto_element_wise |
| |
| - name: mish(Tensor self) -> Tensor |
| self: "GradMode::is_enabled() ? infinitely_differentiable_mish_backward(grad, self) : mish_backward(grad, self)" |
| result: auto_element_wise |
| |
| - name: elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor |
| self: elu_backward(grad, alpha, scale, input_scale, /* is_result */ false, self) |
| result: auto_element_wise |
| |
| - name: elu_(Tensor(a!) self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor(a!) |
| self: elu_backward(grad, alpha, scale, input_scale, /* is_result */ true, result) |
| result: self_t.copy_(elu_backward(original_self_t, alpha, scale, input_scale, /* is_result */ true, result)) |
| |
| - name: celu(Tensor self, Scalar alpha=1.0) -> Tensor |
| self: elu_backward(grad, alpha, 1, 1.0/alpha.toFloat(), /* is_result */ false, self) |
| result: auto_element_wise |
| |
| - name: celu_(Tensor(a!) self, Scalar alpha=1.0) -> Tensor(a!) |
| self: elu_backward(grad, alpha, 1, 1.0/alpha.toFloat(), /* is_result */ true, result) |
| result: self_t.copy_(elu_backward(original_self_t, alpha, 1, 1.0/alpha.toFloat(), /* is_result */ true, result)) |
| |
| - name: gelu(Tensor self, *, str approximate='none') -> Tensor |
| self: gelu_backward(grad, self, approximate) |
| result: auto_element_wise |
| |
| - name: gelu_backward(Tensor grad_output, Tensor self, *, str approximate='none') -> Tensor |
| grad_output: gelu_backward(grad, self, approximate) |
| self: gelu_double_backward(grad, grad_output, self, approximate) |
| result: gelu_backward(grad_output_t, self_p, approximate) + gelu_double_backward(self_t, grad_output_p, self_p, approximate) |
| |
| - name: glu(Tensor self, int dim=-1) -> Tensor |
| # TODO: glu_backward could benefit from the forward result, |
| # and so could forward AD / forward-over-reverse AD, for that matter. |
| self: glu_backward(grad, self, dim) |
| result: glu_jvp(result, self_p, self_t, dim) |
| |
| - name: hardshrink(Tensor self, Scalar lambd=0.5) -> Tensor |
| self: hardshrink_backward(grad, self, lambd) |
| result: auto_element_wise |
| |
| - name: hardshrink_backward(Tensor grad_out, Tensor self, Scalar lambd) -> Tensor |
| grad_out: hardshrink_backward(grad, self, lambd) |
| self: zeros_like(grad) |
| result: at::where((self_p > lambd).logical_or(self_p < -lambd), grad_out_t, at::zeros({}, result.options()).expand_as(result)) |
| |
| - name: hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> Tensor |
| self: hardtanh_backward(grad, self, min_val, max_val) |
| result: auto_element_wise |
| |
| - name: leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor |
| self: leaky_relu_backward(grad, self, negative_slope, false) |
| result: auto_element_wise |
| |
| - name: leaky_relu_(Tensor(a!) self, Scalar negative_slope=0.01) -> Tensor(a!) |
| self: leaky_relu_backward(grad, result, negative_slope, true) |
| result: self_t.copy_(leaky_relu_backward(original_self_t.conj(), result, negative_slope, true).conj()) |
| |
| - name: log_sigmoid_forward(Tensor self) -> (Tensor output, Tensor buffer) |
| self: log_sigmoid_backward(grad, self, buffer) |
| output: log_sigmoid_backward(self_t.conj(), self_p, buffer).conj() |
| |
| - name: _log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor |
| self: _log_softmax_backward_data(grad, result, dim, self.scalar_type()) |
| result: self_t - logsumexp_jvp(self_p, self_t, {dim}, true) |
| |
| - name: _sparse_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor |
| self: _sparse_log_softmax_backward_data(grad, result, dim, self) |
| |
| - name: _masked_softmax(Tensor self, Tensor mask, int? dim=None, int? mask_type=None) -> Tensor |
| self: _masked_softmax_backward(grad, result, mask, dim) |
| mask: non_differentiable |
| |
| - name: _prelu_kernel(Tensor self, Tensor weight) -> Tensor |
| self, weight: "grad.defined() ? _prelu_kernel_backward(grad, self, weight) : std::tuple<Tensor, Tensor>()" |
| result: at::where(self_p >= 0, self_t, weight_p * self_t + weight_t * self_p) |
| |
| - name: _prelu_kernel_backward(Tensor grad_output, Tensor self, Tensor weight) -> (Tensor, Tensor) |
| grad_output: "grads[0].defined() ? |
| (grads[1].defined() ? at::where(self >= 0, grads[0], grads[0] * weight + grads[1] * self) |
| : at::where(self >= 0, grads[0], grads[0] * weight)) |
| : at::where(self >= 0, at::zeros({}, grad_output.options()), grads[1] * self)" |
| self: "grads[1].defined() ? at::where(self >= 0, at::zeros({}, self.options()), grad_output * grads[1]) : zeros_like(self)" |
| weight: "grads[0].defined() ? at::where(self >= 0, at::zeros({}, weight.options()), grad_output * grads[0]) : zeros_like(self)" |
| result0: at::where(self_p >= 0, grad_output_t, grad_output_t * weight_p + grad_output_p * weight_t) |
| result1: at::where(self_p >= 0, at::zeros({}, self_p.options()), grad_output_p * self_t + grad_output_t * self_p) |
| |
| - name: rrelu_with_noise(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor |
| self: rrelu_with_noise_backward(grad, self, noise, lower, upper, training, false) |
| result: auto_element_wise |
| |
| - name: rrelu_with_noise_(Tensor(a!) self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!) |
| self: rrelu_with_noise_backward(grad, result, noise, lower, upper, training, true) |
| |
| - name: _softmax(Tensor self, int dim, bool half_to_float) -> Tensor |
| self: _softmax_backward_data(grad, result, dim, self.scalar_type()) |
| result: result * (self_t - logsumexp_jvp(self_p, self_t, {dim}, true)) |
| |
| - name: _sparse_softmax(Tensor self, int dim, bool half_to_float) -> Tensor |
| self: _sparse_softmax_backward_data(grad, result, dim, self) |
| |
| - name: _sparse_sparse_matmul(Tensor self, Tensor other) -> Tensor |
| self: sparse_sparse_matmul_backward(grad, self, other, 0) |
| other: sparse_sparse_matmul_backward(grad, self, other, 1) |
| |
| - name: softplus(Tensor self, Scalar beta=1, Scalar threshold=20) -> Tensor |
| self: softplus_backward(grad, self, beta, threshold) |
| result: auto_element_wise |
| |
| - name: softshrink(Tensor self, Scalar lambd=0.5) -> Tensor |
| self: softshrink_backward(grad, self, lambd) |
| result: auto_element_wise |
| |
| - name: threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor |
| self: threshold_backward(grad, self, threshold) |
| result: auto_element_wise |
| |
| - name: threshold_(Tensor(a!) self, Scalar threshold, Scalar value) -> Tensor(a!) |
| self: threshold_backward(grad, self, threshold) |
| result: self_t.copy_(threshold_backward(self_t.conj(), original_self_p, threshold).conj()) |
| |
| - name: reflection_pad1d(Tensor self, SymInt[2] padding) -> Tensor |
| self: reflection_pad1d_backward_symint(grad, self, padding) |
| result: auto_linear |
| |
| - name: reflection_pad2d(Tensor self, SymInt[4] padding) -> Tensor |
| self: reflection_pad2d_backward_symint(grad, self, padding) |
| result: auto_linear |
| |
| - name: reflection_pad3d(Tensor self, SymInt[6] padding) -> Tensor |
| self: reflection_pad3d_backward_symint(grad, self, padding) |
| result: auto_linear |
| |
| - name: replication_pad1d(Tensor self, SymInt[2] padding) -> Tensor |
| self: replication_pad1d_backward_symint(grad, self, padding) |
| result: auto_linear |
| |
| - name: replication_pad2d(Tensor self, SymInt[4] padding) -> Tensor |
| self: replication_pad2d_backward_symint(grad, self, padding) |
| result: auto_linear |
| |
| - name: replication_pad3d(Tensor self, SymInt[6] padding) -> Tensor |
| self: replication_pad3d_backward_symint(grad, self, padding) |
| result: auto_linear |
| |
| - name: upsample_linear1d(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None) -> Tensor |
| self: upsample_linear1d_backward_symint(grad, output_size, self.sym_sizes(), align_corners, scales) |
| result: auto_linear |
| |
| - name: upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor |
| self: upsample_bilinear2d_backward_symint(grad, output_size, self.sym_sizes(), align_corners, scales_h, scales_w) |
| result: auto_linear |
| |
| - name: _upsample_bilinear2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor |
| self: _upsample_bilinear2d_aa_backward_symint(grad, output_size, self.sym_sizes(), align_corners, scales_h, scales_w) |
| result: auto_linear |
| |
| - name: upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor |
| self: upsample_bicubic2d_backward_symint(grad, output_size, self.sym_sizes(), align_corners, scales_h, scales_w) |
| result: auto_linear |
| |
| - name: _upsample_bicubic2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor |
| self: _upsample_bicubic2d_aa_backward_symint(grad, output_size, self.sym_sizes(), align_corners, scales_h, scales_w) |
| result: auto_linear |
| |
| - name: upsample_trilinear3d(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor |
| self: upsample_trilinear3d_backward_symint(grad, output_size, self.sym_sizes(), align_corners, scales_d, scales_h, scales_w) |
| result: auto_linear |
| |
| - name: upsample_nearest1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor |
| self: upsample_nearest1d_backward_symint(grad, output_size, self.sym_sizes(), scales) |
| result: auto_linear |
| |
| - name: _upsample_nearest_exact1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor |
| self: _upsample_nearest_exact1d_backward_symint(grad, output_size, self.sym_sizes(), scales) |
| result: auto_linear |
| |
| - name: upsample_nearest2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor |
| self: upsample_nearest2d_backward_symint(grad, output_size, self.sym_sizes(), scales_h, scales_w) |
| result: auto_linear |
| |
| - name: _upsample_nearest_exact2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor |
| self: _upsample_nearest_exact2d_backward_symint(grad, output_size, self.sym_sizes(), scales_h, scales_w) |
| result: auto_linear |
| |
| - name: upsample_nearest3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor |
| self: upsample_nearest3d_backward_symint(grad, output_size, self.sym_sizes(), scales_d, scales_h, scales_w) |
| result: auto_linear |
| |
| - name: _upsample_nearest_exact3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor |
| self: _upsample_nearest_exact3d_backward_symint(grad, output_size, self.sym_sizes(), scales_d, scales_h, scales_w) |
| result: auto_linear |
| |
| - name: pixel_shuffle(Tensor self, int upscale_factor) -> Tensor |
| self: pixel_unshuffle(grad, upscale_factor) |
| result: auto_linear |
| |
| - name: pixel_unshuffle(Tensor self, int downscale_factor) -> Tensor |
| self: pixel_shuffle(grad, downscale_factor) |
| result: auto_linear |
| |
| - name: _adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor |
| self: _adaptive_avg_pool2d_backward(grad, self) |
| result: auto_linear |
| |
| - name: _adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor |
| self: _adaptive_avg_pool3d_backward(grad, self) |
| result: auto_linear |
| |
| - name: adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor) |
| self: adaptive_max_pool2d_backward(grad, self, result1) |
| result0: gather(self_t.flatten(-2), -1, result1.flatten(-2)).view_as(result1) |
| output_differentiability: [True, False] |
| |
| - name: adaptive_max_pool3d(Tensor self, int[3] output_size) -> (Tensor, Tensor) |
| self: adaptive_max_pool3d_backward(grad, self, result1) |
| result0: gather(self_t.flatten(-3), -1, result1.flatten(-3)).view_as(result1) |
| output_differentiability: [True, False] |
| |
| - name: avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor |
| self: avg_pool2d_backward(grad, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) |
| result: auto_linear |
| |
| - name: avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor |
| self: avg_pool3d_backward(grad, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) |
| result: auto_linear |
| |
| - name: fractional_max_pool2d(Tensor self, int[2] kernel_size, int[2] output_size, Tensor random_samples) -> (Tensor, Tensor) |
| self: fractional_max_pool2d_backward(grad, self, kernel_size, output_size, result1) |
| result0: gather(self_t.flatten(-2), -1, result1.flatten(-2)).view_as(result1) |
| output_differentiability: [True, False] |
| |
| - name: fractional_max_pool3d(Tensor self, int[3] kernel_size, int[3] output_size, Tensor random_samples) -> (Tensor, Tensor) |
| self: fractional_max_pool3d_backward(grad, self, kernel_size, output_size, result1) |
| result0: gather(self_t.flatten(-3), -1, result1.flatten(-3)).view_as(result1) |
| output_differentiability: [True, False] |
| |
| - name: linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor |
| input, weight, bias: linear_backward(input, grad, weight, grad_input_mask) |
| |
# mps
| - name: _mps_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor |
| self: mps_max_pool2d_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode) |
| |
| - name: _mps_convolution(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups) -> Tensor |
| self, weight, bias: "grad.defined() ? mps_convolution_backward(self, grad, weight, padding, stride, dilation, groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()" |
| |
| - name: mps_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) |
| grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, std::vector<int64_t>(padding.size(), 0), groups, grad_input_mask) |
| |
| - name: max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor) |
| self: max_pool2d_with_indices_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode, result1) |
| result0: gather(self_t.flatten(-2), -1, result1.flatten(-2)).view_as(result1) |
| output_differentiability: [True, False] |
| |
| - name: max_pool3d_with_indices(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor) |
| self: max_pool3d_with_indices_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode, result1) |
| result0: gather(self_t.flatten(-3), -1, result1.flatten(-3)).view_as(result1) |
| output_differentiability: [True, False] |
| |
| - name: max_unpool2d(Tensor self, Tensor indices, int[2] output_size) -> Tensor |
| self: max_pool_double_backward(grad, indices, 2) |
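# Note: max_pool_double_backward(grad, indices, k) roughly computes
# grad.flatten(-k).gather(-1, indices.flatten(-k)).view_as(indices), i.e. it
# gathers grad at the saved indices over the last k spatial dims (the adjoint
# of scattering into the unpooled positions). The same helper provides the
# grad_output gradient of the max-pool backward functions further down.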
| indices: non_differentiable |
| result: auto_linear |
| |
| - name: max_unpool3d(Tensor self, Tensor indices, int[3] output_size, int[3] stride, int[3] padding) -> Tensor |
| self: max_pool_double_backward(grad, indices, 3) |
| indices: non_differentiable |
| result: auto_linear |
| |
| - name: convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, SymInt[] padding, int[] dilation, bool transposed, SymInt[] output_padding, int groups) -> Tensor |
| input, weight, bias: "grad.defined() ? convolution_backward_symint(grad, input, weight, bias->sym_sizes(), stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()" |
| result: convolution_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, stride, padding, dilation, transposed, output_padding, groups) |
| |
| # TorchScript serializes calls to _convolution so this entry is present until that is changed to use convolution. |
| # Note that the benchmark, deterministic, cudnn_enabled, and allow_tf32 flags are queried from the global context |
| # by convolution_backward instead of being passed along from the forward pass. |
| - name: _convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, SymInt[] padding, int[] dilation, bool transposed, SymInt[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor |
| input, weight, bias: "grad.defined() ? convolution_backward_symint(grad, input, weight, bias->sym_sizes(), stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()" |
| result: _convolution_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled, allow_tf32) |
| |
| - name: convolution_backward(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, int[] stride, SymInt[] padding, int[] dilation, bool transposed, SymInt[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) |
| grad_output, input, weight: _convolution_double_backward_symint(grads[0], grads[1], grads[2], grad_output, weight, input, stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) |
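# Note: grad_input is bilinear in (grad_output, weight) and grad_weight is
# bilinear in (grad_output, input), so result0 and result1 below are
# product-rule sums of two convolution_backward calls: one pairing the primal
# grad_output with the other operand's tangent, and one pairing grad_output's
# tangent with the primal operand. The bias gradient is linear in grad_output,
# so result2 only needs grad_output_t.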
| result0: std::get<0>(convolution_backward_symint(grad_output_p, input_p, weight_t, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, {true, false, false})) + std::get<0>(convolution_backward_symint(grad_output_t, input_p, weight_p, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, {true, false, false})) |
| result1: std::get<1>(convolution_backward_symint(grad_output_p, input_t, weight_p, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, {false, true, false})) + std::get<1>(convolution_backward_symint(grad_output_t, input_p, weight_p, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, {false, true, false})) |
| result2: convolution_backward_jvp_grad_bias(grad_output_t, result2) |
| |
| - name: convolution_overrideable(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor |
| input, weight, bias: "grad.defined() ? convolution_backward_overrideable(grad, input, weight, stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()" |
| |
| - name: convolution_backward_overrideable(Tensor grad_output, Tensor input, Tensor weight, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) |
| grad_output, input, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, input, stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) |
| |
| - name: slow_conv_transpose2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, int[2] dilation=1) -> Tensor |
| self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, true, output_padding, 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()" |
| |
| - name: slow_conv_transpose3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, int[3] dilation=1) -> Tensor |
| self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, true, output_padding, 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()" |
| |
| - name: _slow_conv2d_forward(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, int[2] padding) -> Tensor |
| self, weight, bias: "grad.defined() ? _slow_conv2d_backward(grad, self, weight, kernel_size, stride, padding, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()" |
| |
| - name: _slow_conv2d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) |
| grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, {{1, 1}}, false, {{0, 0}}, 1, grad_input_mask) |
| |
| - name: _conv_depthwise2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, SymInt[2] padding, int[2] dilation) -> Tensor |
| self, weight, bias: "grad.defined() ? convolution_backward_symint(grad.contiguous(), self, weight, bias->sym_sizes(), stride, padding, dilation, /*transposed=*/ false, /*output_padding=*/ {{0, 0}}, /*groups=*/ 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()" |
| |
| - name: conv_depthwise3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias, int[3] stride, SymInt[3] padding, int[3] dilation) -> Tensor |
| self, weight, bias: "grad.defined() ? convolution_backward_symint(grad.contiguous(), self, weight, bias->sym_sizes(), stride, padding, dilation, /*transposed=*/ false, /*output_padding=*/ {{0, 0, 0}}, /*groups=*/ 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()" |
| |
| - name: slow_conv3d_forward(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias, int[3] stride, SymInt[3] padding) -> Tensor |
| self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, /*dilation=*/ {{1, 1, 1}}, false, /*output_padding=*/ {{0, 0, 0}}, 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()" |
| |
| - name: slow_conv_dilated2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, SymInt[2] padding=0, int[2] dilation=1) -> Tensor |
| self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector<c10::SymInt>(padding.size(), 0), 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()" |
| |
| - name: slow_conv_dilated3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, SymInt[3] padding=0, int[3] dilation=1) -> Tensor |
| self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector<c10::SymInt>(padding.size(), 0), 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()" |
| |
| - name: col2im(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor |
| self: im2col(grad, kernel_size, dilation, padding, stride) |
| result: auto_linear |
| |
| - name: im2col(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor |
| self: col2im_symint(grad, {self.sym_size(-2), self.sym_size(-1)}, kernel_size, dilation, padding, stride) |
| result: auto_linear |
| |
| - name: _adaptive_avg_pool2d_backward(Tensor grad_output, Tensor self) -> Tensor |
| grad_output: _adaptive_avg_pool2d_symint(grad, {grad_output.sym_size(-2), grad_output.sym_size(-1)}) |
| self: zeros_like(self) |
| result: _adaptive_avg_pool2d_backward(grad_output_t, self_p) |
| |
| - name: _adaptive_avg_pool3d_backward(Tensor grad_output, Tensor self) -> Tensor |
| grad_output: _adaptive_avg_pool3d_symint(grad, { grad_output.sym_size(-3), grad_output.sym_size(-2), grad_output.sym_size(-1) }) |
| self: zeros_like(self) |
| result: _adaptive_avg_pool3d_backward(grad_output_t, self_p) |
| |
| - name: adaptive_max_pool2d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor |
| grad_output: max_pool_double_backward(grad, indices, 2) |
| self: zeros_like(self) |
| result: auto_linear |
| |
| - name: adaptive_max_pool3d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor |
| grad_output: max_pool_double_backward(grad, indices, 3) |
| self: zeros_like(self) |
| result: auto_linear |
| |
| - name: avg_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor |
| grad_output: avg_pool2d(grad, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) |
| self: zeros_like(self) |
| result: avg_pool2d_backward(grad_output_t, self_p, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) |
| |
| - name: avg_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor |
| grad_output: avg_pool3d(grad, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) |
| self: zeros_like(self) |
| result: avg_pool3d_backward(grad_output_t, self_p, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) |
| |
| - name: elu_backward(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor self_or_result) -> Tensor |
| grad_output: elu_backward(grad, alpha, scale, input_scale, is_result, self_or_result) |
| self_or_result: elu_double_backward(grad, grad_output, alpha, scale, input_scale, is_result, self_or_result) |
| result: elu_backward(grad_output_t, alpha, scale, input_scale, is_result, self_or_result_p) + elu_double_backward(self_or_result_t, grad_output_p, alpha, scale, input_scale, is_result, self_or_result_p) |
| |
| - name: fractional_max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] output_size, Tensor indices) -> Tensor |
| grad_output: max_pool_double_backward(grad, indices, 2) |
| self: zeros_like(self) |
| result: auto_linear |
| |
| - name: fractional_max_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] output_size, Tensor indices) -> Tensor |
| grad_output: max_pool_double_backward(grad, indices, 3) |
| self: zeros_like(self) |
| result: auto_linear |
| |
| - name: glu_backward(Tensor grad_output, Tensor self, int dim) -> Tensor |
| grad_output: glu_double_backward_grad_output(grad, self, dim) |
| self: glu_double_backward(grad, grad_output, self, dim) |
| result: glu_backward_jvp(result, grad_output_p, self_p, grad_output_t, self_t, dim) |
| |
| - name: hardtanh_backward(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val) -> Tensor |
| grad_output: hardtanh_backward(grad, self, min_val, max_val) |
| self: zeros_like(grad) |
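# Note: hardtanh_backward is a masked pass-through of grad_output, so its
# forward-mode derivative below simply forwards grad_output_t where self_p lies
# strictly inside (min_val, max_val) and is zero elsewhere (softshrink_backward
# further down uses the same pattern with its own mask).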
| result: at::where((self_p > min_val).logical_and(self_p < max_val), grad_output_t, at::zeros({}, result.options()).expand_as(result)) |
| |
| - name: log_sigmoid_backward(Tensor grad_output, Tensor self, Tensor buffer) -> Tensor |
| grad_output: log_sigmoid_backward(grad, self, buffer) |
| self: log_sigmoid_double_backward(grad * grad_output, self) |
| |
| - name: _log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor |
| grad_output: grad.to(output.dtype()) - (grad.to(output.dtype()) * output.exp()).sum(dim, true) |
| output: (-grad_output.sum(dim, true) * output.exp() * grad.to(output.dtype())).to(output.dtype()) |
| |
| - name: leaky_relu_backward(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result) -> Tensor |
# self_is_result is always false here since the double backward call is an out-of-place call; self is the input itself.
| grad_output: leaky_relu_backward(grad, self, negative_slope, false) |
| self: zeros_like(grad) |
| # leaky_relu_backward(grad_output, self, negative_slope, false) |
| # computes grad_output * at::where(self_p > 0, 1, negative_slope) |
| # so the jvp formula is the following: |
| # grad_output_t * at::where(self_p > 0, self_p.new_ones([]), negative_slope); |
| # |
| # leaky_relu_backward(grad_output, result, negative_slope, true) |
| # computes grad_output * at::where(result > 0, 1, negative_slope) |
| # under the assumption that `negative_slope` is positive (otherwise, |
| # it is not possible to compute the gradient). |
| # |
| # so the jvp formula is the following: |
| # grad_output_t * at::where(result_p > 0, result_p.new_ones([]), negative_slope); |
| # with the assumption that negative_slope is positive. |
| # |
# Combining the two cases results in the following optimized kernel, which
# also checks the assumption that negative_slope is positive when self_is_result
# is true:
| result: leaky_relu_backward(grad_output_t, self_p, negative_slope, self_is_result) |
| |
| - name: max_pool2d_with_indices_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices) -> Tensor |
| grad_output: max_pool_double_backward(grad, indices, 2) |
| self: zeros_like(self) |
| indices: non_differentiable |
| result: auto_linear |
| |
| - name: max_pool3d_with_indices_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool ceil_mode, Tensor indices) -> Tensor |
| grad_output: max_pool_double_backward(grad, indices, 3) |
| self: zeros_like(self) |
| indices: non_differentiable |
| result: auto_linear |
| |
| - name: mse_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor |
| grad_output: mse_loss_backward(grad, self, target, reduction) |
| self: mse_loss_double_backward(grad * grad_output, self, reduction) |
| target: -mse_loss_double_backward(grad * grad_output, target, reduction) |
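# Note: mse_loss_backward is grad_output times a term linear in (self - target),
# so its forward-mode derivative below is a product-rule sum with one term per
# differentiable input; the target term carries a minus sign because the loss
# depends only on (self - target). smooth_l1_loss_backward further down follows
# the same structure.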
| result: " mse_loss_double_backward(self_t * grad_output_p, self_p, reduction) |
| - mse_loss_double_backward(target_t * grad_output_p, target_p, reduction) |
| + mse_loss_backward(grad_output_t, self_p, target_p, reduction) |
| " |
| |
| - name: nll_loss_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight) -> Tensor |
| grad_output: nll_loss_symint(grad, target, weight, reduction, ignore_index) |
| self: zeros_like(grad) |
| target: non_differentiable |
| |
| - name: nll_loss2d_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight) -> Tensor |
| grad_output: nll_loss2d_symint(grad, target, weight, reduction, ignore_index) |
| self: zeros_like(grad) |
| target: non_differentiable |
| |
| - name: rrelu_with_noise_backward(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, bool self_is_result) -> Tensor |
# self_is_result is always false here since the double backward call is an out-of-place call; self is the input itself.
| grad_output: rrelu_with_noise_backward(grad, self, noise, lower, upper, training, false) |
| self: zeros_like(grad) |
| result: rrelu_with_noise_backward(grad_output_t, self_p, noise, lower, upper, training, false) |
| |
| - name: reflection_pad1d_backward(Tensor grad_output, Tensor self, SymInt[2] padding) -> Tensor |
| grad_output: reflection_pad1d_symint(grad, padding) |
| self: zeros_like(self) |
| result: reflection_pad1d_backward_symint(grad_output_t, self_p, padding) |
| |
| - name: reflection_pad2d_backward(Tensor grad_output, Tensor self, SymInt[4] padding) -> Tensor |
| grad_output: reflection_pad2d_symint(grad, padding) |
| self: zeros_like(self) |
| result: reflection_pad2d_backward_symint(grad_output_t, self_p, padding) |
| |
| - name: reflection_pad3d_backward(Tensor grad_output, Tensor self, SymInt[6] padding) -> Tensor |
| grad_output: reflection_pad3d_symint(grad, padding) |
| self: zeros_like(self) |
| result: reflection_pad3d_backward_symint(grad_output_t, self_p, padding) |
| |
| - name: replication_pad1d_backward(Tensor grad_output, Tensor self, SymInt[2] padding) -> Tensor |
| grad_output: replication_pad1d_symint(grad, padding) |
| self: zeros_like(self) |
| result: replication_pad1d_backward_symint(grad_output_t, self_p, padding) |
| |
| - name: replication_pad2d_backward(Tensor grad_output, Tensor self, SymInt[4] padding) -> Tensor |
| grad_output: replication_pad2d_symint(grad, padding) |
| self: zeros_like(self) |
| result: replication_pad2d_backward_symint(grad_output_t, self_p, padding) |
| |
| - name: replication_pad3d_backward(Tensor grad_output, Tensor self, SymInt[6] padding) -> Tensor |
| grad_output: replication_pad3d_symint(grad, padding) |
| self: zeros_like(self) |
| result: replication_pad3d_backward_symint(grad_output_t, self_p, padding) |
| |
| - name: sparse_sampled_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor |
| self: maybe_multiply(grad, beta.conj()) |
| mat1: maybe_multiply(grad.sparse_mask(self).mm(mat2.mH()), alpha.conj()) |
| mat2: maybe_multiply(mat1.mH().mm(grad.sparse_mask(self)), alpha.conj()) |
| |
| - name: smooth_l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float beta) -> Tensor |
| grad_output: smooth_l1_loss_backward(grad, self, target, reduction, beta) |
| self: smooth_l1_loss_double_backward(grad * grad_output, self, target, reduction, beta) |
| target: -smooth_l1_loss_double_backward(grad * grad_output, self, target, reduction, beta) |
| result: " smooth_l1_loss_double_backward(self_t * grad_output_p, self_p, target_p, reduction, beta) |
| - smooth_l1_loss_double_backward(target_t * grad_output_p, self_p, target_p, reduction, beta) |
| + smooth_l1_loss_backward(grad_output_t, self_p, target_p, reduction, beta) |
| " |
| |
| - name: huber_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float delta) -> Tensor |
| grad_output: huber_loss_double_backward_grad_output(grad, grad_output, self, target, reduction, delta) |
| self: huber_loss_double_backward(grad * grad_output, self, target, reduction, delta) |
| target: -huber_loss_double_backward(grad * grad_output, self, target, reduction, delta) |
| |
| - name: softplus_backward(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold) -> Tensor |
| grad_output: softplus_backward(grad, self, beta, threshold) |
| self: softplus_double_backward(grad * grad_output, self, beta, threshold) |
| result: "softplus_backward(grad_output_t, self_p, beta, threshold) |
| + softplus_double_backward(self_t * grad_output_p, self_p, beta, threshold)" |
| |
| - name: _softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor |
| grad_output: _softmax_backward_data(grad.to(output.dtype()), output, dim, input_dtype) |
| output: softmax_double_backward(grad.to(output.dtype()), grad_output, dim, output).to(output.dtype()) |
| |
| - name: soft_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor |
| grad_output: soft_margin_loss_double_backward_grad_output(grad, grad_output, self, target, reduction) |
| self: soft_margin_loss_double_backward(grad * grad_output, self, target, reduction) |
| |
| - name: softshrink_backward(Tensor grad_output, Tensor self, Scalar lambd) -> Tensor |
| grad_output: softshrink_backward(grad, self, lambd) |
| self: zeros_like(grad) |
| result: at::where((self_p > lambd).logical_or(self_p < -lambd), grad_output_t, at::zeros({}, result.options()).expand_as(result)) |
| |
| - name: threshold_backward(Tensor grad_output, Tensor self, Scalar threshold) -> Tensor |
| grad_output: threshold_backward(grad, self, threshold) |
| self: zeros_like(grad) |
| result: zeros_like(self_t) + threshold_backward(grad_output_t, self_p, threshold) |
| |
| - name: upsample_linear1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, bool align_corners, float? scales=None) -> Tensor |
| grad_output: upsample_linear1d_symint(grad, output_size, align_corners, scales) |
| result: auto_linear |
| |
| - name: upsample_bilinear2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor |
| grad_output: upsample_bilinear2d_symint(grad, output_size, align_corners, scales_h, scales_w) |
| result: auto_linear |
| |
| - name: _upsample_bilinear2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor |
| grad_output: _upsample_bilinear2d_aa_symint(grad, output_size, align_corners, scales_h, scales_w) |
| result: auto_linear |
| |
| - name: upsample_bicubic2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor |
| grad_output: upsample_bicubic2d_symint(grad, output_size, align_corners, scales_h, scales_w) |
| result: auto_linear |
| |
| - name: _upsample_bicubic2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor |
| grad_output: _upsample_bicubic2d_aa_symint(grad, output_size, align_corners, scales_h, scales_w) |
| result: auto_linear |
| |
| - name: upsample_trilinear3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor |
| grad_output: upsample_trilinear3d_symint(grad, output_size, align_corners, scales_d, scales_h, scales_w) |
| result: auto_linear |
| |
| - name: upsample_nearest1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None) -> Tensor |
| grad_output: upsample_nearest1d_symint(grad, output_size, scales) |
| result: auto_linear |
| |
| - name: _upsample_nearest_exact1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None) -> Tensor |
| grad_output: _upsample_nearest_exact1d_symint(grad, output_size, scales) |
| result: auto_linear |
| |
| - name: upsample_nearest2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None) -> Tensor |
| grad_output: upsample_nearest2d_symint(grad, output_size, scales_h, scales_w) |
| result: auto_linear |
| |
| - name: _upsample_nearest_exact2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None) -> Tensor |
| grad_output: _upsample_nearest_exact2d_symint(grad, output_size, scales_h, scales_w) |
| result: auto_linear |
| |
| - name: upsample_nearest3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor |
| grad_output: upsample_nearest3d_symint(grad, output_size, scales_d, scales_h, scales_w) |
| result: auto_linear |
| |
| - name: _upsample_nearest_exact3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor |
| grad_output: _upsample_nearest_exact3d_symint(grad, output_size, scales_d, scales_h, scales_w) |
| result: auto_linear |
| |
| - name: sigmoid_backward(Tensor grad_output, Tensor output) -> Tensor |
| grad_output: sigmoid_backward(grad, output.conj()) |
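# Note: sigmoid_backward(g, y) computes g * y * (1 - y) (conjugated as needed
# for complex inputs), so differentiating with respect to y gives g * (1 - 2y),
# which is where the (-2 * output.conj() + 1) factor below comes from. The
# tanh_backward entry below is analogous: g * (1 - y * y) differentiates to
# -2 * y * g.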
| output: grad.conj() * grad_output * (-2 * output.conj() + 1) |
| result: sigmoid_backward(grad_output_t, output_p) + output_t.conj() * grad_output_p * (-2 * output_p.conj() + 1) |
| |
| - name: tanh_backward(Tensor grad_output, Tensor output) -> Tensor |
| grad_output: tanh_backward(grad, output.conj()) |
| output: grad.conj() * (-2 * output.conj() * grad_output) |
| result: tanh_backward(grad_output_t, output_p) + output_t.conj() * (-2 * output_p.conj() * grad_output_p) |
| |
| # cudnn |
| - name: _cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor) |
| log_probs: _cudnn_ctc_loss_backward(grad, result0, result1, zero_infinity) |
| |
| - name: _cudnn_ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor) |
| log_probs: _cudnn_ctc_loss_backward(grad, result0, result1, zero_infinity) |
| |
| - name: cudnn_convolution_transpose(Tensor self, Tensor weight, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor |
| self, weight: "_cudnn_convolution_backward(self, grad, weight, padding, output_padding, stride, dilation, true, groups, {grad_input_mask[0], grad_input_mask[1]})" |
| |
| - name: _mps_convolution_transpose(Tensor self, Tensor weight, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups) -> Tensor |
| self, weight: "grad.defined() ? mps_convolution_transpose_backward(self, grad, weight, padding, output_padding, stride, dilation, groups, grad_input_mask) : std::tuple<Tensor, Tensor>()" |
| |
| - name: cudnn_convolution(Tensor self, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor |
| self, weight: "_cudnn_convolution_backward(self, grad, weight, padding, std::vector<int64_t>(padding.size(), 0), stride, dilation, false, groups, {grad_input_mask[0], grad_input_mask[1]})" |
| |
| - name: cudnn_grid_sampler(Tensor self, Tensor grid) -> Tensor output |
| self, grid: "grad.defined() ? cudnn_grid_sampler_backward(self, grid, grad) : std::tuple<Tensor, Tensor>()" |
| |
| - name: cudnn_affine_grid_generator(Tensor theta, int N, int C, int H, int W) -> Tensor grid |
| theta: cudnn_affine_grid_generator_backward(grad, N, C, H, W) |
| |
# NB: Why is the backward here so complicated? CuDNN cannot be used to compute
# the backward in evaluation mode, because the math for the backward in
# evaluation mode is different (since the forward math is different), and CuDNN
# does not support it. In any case, you shouldn't be using this batch norm in
# evaluation mode, because it should be merged into the preceding convolution
# (left for future work).
| # NB2: The quotes around the gradient are needed to appease YAML parsing rules. |
| - name: cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor, Tensor) |
| input, weight, bias: "grad.defined() ? (training ? cudnn_batch_norm_backward(input, grad.contiguous(input.suggest_memory_format()), weight, running_mean, running_var, result1, result2, epsilon, retain_variables ? result3.clone() : result3) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)) : std::tuple<Tensor, Tensor, Tensor>()" |
| result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, epsilon) |
| |
# HACK: save_mean and save_var are going to be passed in as
# requires_grad variables (even though we'll never backprop through
# them), so we need to prevent the unpacking from triggering an error.
| - name: cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, Tensor reserveSpace) -> (Tensor, Tensor, Tensor) |
| save_mean: not_implemented("cudnn_batch_norm_backward save_mean") |
| save_var: not_implemented("cudnn_batch_norm_backward save_var") |
| reserveSpace: not_implemented("cudnn_batch_norm_backward reserveSpace") |
| input, weight, grad_output: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_output, running_mean, running_var, true, epsilon, save_mean, save_var, grad_input_mask) |
| |
| # nnpack |
| |
| - name: _nnpack_spatial_convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[2] padding, int[2] stride=1) -> Tensor |
# NNPACK does not support strided convolutions in the backward path, so we use the closest available function that does support them.
| input, weight, bias: "grad.defined() ? convolution_backward_symint(grad, input, weight, bias->sym_sizes(), stride, padding, std::vector<int64_t>(padding.size(), 1), false, std::vector<c10::SymInt>(padding.size(), 0), 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()" |
| |
# LSTM MPS
| - name: _lstm_mps(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor, Tensor, Tensor) |
| output_differentiability: [True, True, True, False, False] |
| input, hx, params: "lstm_mps_backward(grads[0], grads[1], grads[2], result3, result4, input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first)" |
| |
- name: lstm_mps_backward(Tensor grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor[], Tensor[])

# Only the first three of _cudnn_rnn outputs can have gradients.
| # _cudnn_rnn outputs: (output, hy, cy, reserve, weight_buf) |
| - name: _cudnn_rnn(Tensor input, Tensor[] weight, int weight_stride0, Tensor? weight_buf, Tensor hx, Tensor? cx, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor) |
| dropout_state: non_differentiable |
| output_differentiability: [True, True, True, False, False] |
| input, hx, cx, weight: "_cudnn_rnn_backward_symint(input, weight, weight_stride0, result4, hx, cx, result0, grads[0], grads[1], grads[2], mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, retain_variables ? result3.clone() : result3, grad_input_mask)" |
| |
| - name: _cudnn_rnn_backward(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask) -> (Tensor, Tensor, Tensor, Tensor[]) |
| dropout_state: non_differentiable |
| input: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg) |
| weight: not_implemented_list("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg) |
| hx: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg) |
| cx: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg) |
| output: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg) |
| grad_output: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg) |
| grad_hy: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg) |
| grad_cy: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg) |
| |
| # miopen |
| |
| - name: miopen_convolution_transpose(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor |
| self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, true, output_padding, groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()" |
| |
| - name: miopen_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor |
| self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector<c10::SymInt>(padding.size(), 0), groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()" |
| |
| - name: miopen_depthwise_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor |
| self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector<c10::SymInt>(padding.size(), 0), groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()" |
| |
| - name: miopen_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor) |
| input, weight, bias: "grad.defined() ? (training ? miopen_batch_norm_backward(input, grad.contiguous(), weight, running_mean, running_var, result1, result2, epsilon) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)) : std::tuple<Tensor, Tensor, Tensor>()" |
| |
| - name: miopen_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon) -> (Tensor, Tensor, Tensor) |
| save_mean: not_implemented("miopen_batch_norm_backward save_mean") |
| save_var: not_implemented("miopen_batch_norm_backward save_var") |
| input, weight, grad_output: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_output, running_mean, running_var, true, epsilon, save_mean, save_var, grad_input_mask) |
| |
| - name: miopen_rnn(Tensor input, Tensor[] weight, int weight_stride0, Tensor hx, Tensor? cx, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor) |
| dropout_state: non_differentiable |
| output_differentiability: [True, True, True, False, False] |
| input, hx, cx, weight: "miopen_rnn_backward(input, weight, weight_stride0, result4, hx, cx, result0, grads[0], grads[1], grads[2], mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, retain_variables ? result3.clone() : result3, grad_input_mask)" |
| |
| - name: miopen_rnn_backward(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask) -> (Tensor, Tensor, Tensor, Tensor[]) |
| dropout_state: non_differentiable |
| |
| - name: mkldnn_rnn_layer(Tensor input, Tensor weight0, Tensor weight1, Tensor weight2, Tensor weight3, Tensor hx_, Tensor cx_, bool reverse, int[] batch_sizes, int mode, int hidden_size, int num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train) -> (Tensor, Tensor, Tensor, Tensor) |
| output_differentiability: [True, True, True, False] |
| input, weight0, weight1, weight2, weight3, hx_, cx_: "mkldnn_rnn_layer_backward(input, weight0, weight1, weight2, weight3, hx_, cx_, result0, result1, result2, grads[0], grads[1], grads[2], reverse, mode, hidden_size, num_layers, has_biases, train, bidirectional, batch_sizes, batch_first, result3)" |
| |
| - name: mkldnn_rnn_layer_backward(Tensor input, Tensor weight1, Tensor weight2, Tensor weight3, Tensor weight4, Tensor hx_, Tensor cx_tmp, Tensor output, Tensor hy_, Tensor cy_, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, bool reverse, int mode, int hidden_size, int num_layers, bool has_biases, bool train, bool bidirectional, int[] batch_sizes, bool batch_first, Tensor workspace) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) |
| |
| # mkldnn |
| - name: mkldnn_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, int[] stride, int[] dilation, int groups) -> Tensor |
| self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, /*transposed=*/ false, /*output_padding=*/ std::vector<c10::SymInt>(padding.size(), 0), groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()" |
| |
| - name: mkldnn_linear(Tensor self, Tensor weight, Tensor? bias=None) -> Tensor |
| self, weight, bias: mkldnn_linear_backward(self, grad, weight, grad_input_mask) |
| |
| - name: mkldnn_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor |
| self: mkldnn_max_pool2d_backward(grad, result, self, kernel_size, stride, padding, dilation, ceil_mode) |
| |
| - name: mkldnn_max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor |
| self: mkldnn_max_pool3d_backward(grad, result, self, kernel_size, stride, padding, dilation, ceil_mode) |
| |
| - name: mkldnn_adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor |
| self: mkldnn_adaptive_avg_pool2d_backward(grad, self) |
| |
| - name: _mkldnn_reshape(Tensor self, int[] shape) -> Tensor |
| self: grad.reshape_symint(self.sym_sizes()) |
| |
| # NestedTensor |
| - name: _nested_tensor_from_tensor_list(Tensor[] list, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor |
list: "grad.defined() ? at::unbind(grad) : std::vector<Tensor>(list.size())"
| |
| - name: _nested_tensor_from_mask(Tensor t, Tensor mask, bool mask_check=True) -> Tensor |
| t: grad.to_padded_tensor_symint(0, t.sym_sizes()) |
| mask: non_differentiable |
| |
| - name: _nested_from_padded(Tensor padded, Tensor cpu_nested_shape_example, bool fuse_transform_0213=False) -> Tensor |
| padded: _nested_from_padded_backward(grad, padded, fuse_transform_0213) |
| cpu_nested_shape_example: non_differentiable |
| |
| - name: to_padded_tensor(Tensor self, float padding, SymInt[]? output_size=None) -> Tensor |
| self: at::_nested_from_padded(grad, self._nested_tensor_size()) |
| padding: non_differentiable |
| |
| - name: _nested_view_from_buffer(Tensor(a) self, Tensor nested_size, Tensor nested_strides, int[] offsets) -> Tensor(a) |
| self: grad.values() |
| nested_size: non_differentiable |
| nested_strides: non_differentiable |
| |
| # Transformers |
| - name: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, bool compute_log_sumexp, bool is_causal=False) -> (Tensor, Tensor) |
| output_differentiability: [True, False] |
| query, key, value: _scaled_dot_product_efficient_attention_backward(grad, query, key, value, result0, result1, is_causal, at::_chunk_grad_outputs_efficient_attention(query, key, value, is_causal)) |
| |
| - name: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, bool compute_log_sumexp=False, bool causal=False) -> (Tensor, Tensor) |
| output_differentiability: [True, False] |
| query, key, value: _efficient_attention_backward(grad, query, key, value, result0, result1, causal, at::_chunk_grad_outputs_efficient_attention(query, key, value, causal)) |
| |
| # fft |
| - name: _fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor |
| self: fft_r2c_backward(grad, dim, normalization, onesided, self.sym_size(dim.back())) |
| result: auto_linear |
| |
| - name: _fft_c2r(Tensor self, int[] dim, int normalization, int last_dim_size) -> Tensor |
| self: fft_c2r_backward(grad, dim, normalization) |
| result: auto_linear |
| |
| - name: _fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor |
| self: _fft_c2c_symint(grad, dim, normalization, !forward) |
| result: auto_linear |
| |
| - name: unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[] |
| self: unbind_backward(grads, dim) |
| result: auto_linear |
| |
| - name: stack(Tensor[] tensors, int dim=0) -> Tensor |
| tensors: stack_tensors_backward(grad, dim, to_args_scalartypes(tensors)) |
| result: stack_jvp(tensors, dim) |
| |
| # fused RNN kernels |
| |
# Only the first two of _thnn_fused_lstm_cell outputs can have gradients.
| # _thnn_fused_lstm_cell outputs: (hy, cy, workspace) |
| - name: _thnn_fused_lstm_cell(Tensor input_gates, Tensor hidden_gates, Tensor cx, Tensor? input_bias=None, Tensor? hidden_bias=None) -> (Tensor, Tensor, Tensor) |
| output_differentiability: [True, True, False] |
| input_gates, hidden_gates, cx, input_bias, hidden_bias: "GradMode::is_enabled() ? _thnn_differentiable_lstm_cell_backward(grads[0], grads[1], input_gates, hidden_gates, input_bias, hidden_bias, cx, result1) : _thnn_fused_lstm_cell_backward(grads[0], grads[1], cx, result1, result2, input_bias.defined())" |
| |
| - name: _thnn_fused_gru_cell(Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias=None, Tensor? hidden_bias=None) -> (Tensor, Tensor) |
| input_gates, hidden_gates, hx, input_bias, hidden_bias: "grad.defined() ? (GradMode::is_enabled() ? _thnn_differentiable_gru_cell_backward(grad, input_gates, hidden_gates, hx, input_bias, hidden_bias) : _thnn_fused_gru_cell_backward(grad, result1, input_bias.defined())) : std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor>()" |
| |
| # PackedSequence helpers |
| - name: _pack_padded_sequence(Tensor input, Tensor lengths, bool batch_first) -> (Tensor, Tensor) |
| input: _pack_padded_sequence_backward_symint(grad, input.sym_sizes(), result1, batch_first) |
| |
| # TH wrappers |
| - name: eq.Scalar(Tensor self, Scalar other) -> Tensor |
| output_differentiability: [False] |
| |
| - name: eq.Tensor(Tensor self, Tensor other) -> Tensor |
| output_differentiability: [False] |
| |
| - name: ge.Scalar(Tensor self, Scalar other) -> Tensor |
| output_differentiability: [False] |
| |
| - name: ge.Tensor(Tensor self, Tensor other) -> Tensor |
| output_differentiability: [False] |
| |
| - name: gt.Scalar(Tensor self, Scalar other) -> Tensor |
| output_differentiability: [False] |
| |
| - name: gt.Tensor(Tensor self, Tensor other) -> Tensor |
| output_differentiability: [False] |
| |
| - name: le.Scalar(Tensor self, Scalar other) -> Tensor |
| output_differentiability: [False] |
| |
| - name: le.Tensor(Tensor self, Tensor other) -> Tensor |
| output_differentiability: [False] |
| |
| - name: lt.Scalar(Tensor self, Scalar other) -> Tensor |
| output_differentiability: [False] |
| |
| - name: lt.Tensor(Tensor self, Tensor other) -> Tensor |
| output_differentiability: [False] |
| |
| - name: ne.Scalar(Tensor self, Scalar other) -> Tensor |
| output_differentiability: [False] |
| |
| - name: ne.Tensor(Tensor self, Tensor other) -> Tensor |
| output_differentiability: [False] |
| |
| - name: multinomial(Tensor self, int num_samples, bool replacement=False, *, Generator? generator=None) -> Tensor |
| output_differentiability: [False] |
| |
| - name: nonzero(Tensor self) -> Tensor |
| output_differentiability: [False] |
| |
| - name: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, Tensor? offsets=None, int axis=0, bool unsafe=False, Scalar? initial=None) -> Tensor |
| data: _segment_reduce_backward(grad, result, data, reduce, lengths, offsets, axis, initial) |
| |
| - name: _pin_memory(Tensor self, Device? device=None) -> Tensor |
| self: grad |
| |
| - name: _new_zeros_with_same_feature_meta(Tensor self, Tensor other, *, int self_num_batch_dims=0) -> Tensor |
| self: non_differentiable |
| other: non_differentiable |
| output_differentiability: [False] |
| |
| - name: _test_warn_in_autograd(Tensor self) -> Tensor |
| self: warn_backwards(grad) |
| |
| - name: _test_autograd_multiple_dispatch.fullcoverage(Tensor self) -> Tensor |
| dispatch: |
| Default: |
| self: grad.expand_symint(self.sym_sizes()) + 1 |
| result: auto_linear |
| AutogradNestedTensor: |
| self: grad.mul(grad) |
| AutogradCUDA: |
| self: grad.expand_symint(self.sym_sizes()) * 2 |
| |
| - name: _test_autograd_multiple_dispatch.ntonly(Tensor self, bool b) -> Tensor |
| dispatch: |
| AutogradNestedTensor: |
| self: grad.mul(grad).add(grad) |
| |
| - name: _test_autograd_multiple_dispatch_view(Tensor(a) self) -> Tensor(a) |
| dispatch: |
| Default: |
| self: grad.reshape_as(self) |
| AutogradCUDA: |
| self: grad.reshape_as(self) + 1 |
| |
| - name: _efficientzerotensor(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor |
| output_differentiability: [False] |
| |
| - name: scatter_reduce.two(Tensor self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True) -> Tensor |
| self, src: scatter_reduce_backward(grad, self, dim, index, src, reduce, include_self, result) |
| index: non_differentiable |
| result: scatter_reduce_jvp(self_p, self_t, dim, index, src_p, src_t, reduce, include_self, result) |
| |
| - name: special_airy_ai(Tensor x) -> Tensor |
| x: non_differentiable |
| |
| - name: special_bessel_j0(Tensor self) -> Tensor |
| self: non_differentiable |
| |
| - name: special_bessel_j1(Tensor self) -> Tensor |
| self: non_differentiable |
| |
| - name: special_bessel_y0(Tensor self) -> Tensor |
| self: non_differentiable |
| |
| - name: special_bessel_y1(Tensor self) -> Tensor |
| self: non_differentiable |
| |
| - name: special_chebyshev_polynomial_t(Tensor x, Tensor n) -> Tensor |
| x: non_differentiable |
| n: non_differentiable |
| |
| - name: special_chebyshev_polynomial_t.x_scalar(Scalar x, Tensor n) -> Tensor |
| n: non_differentiable |
| |
| - name: special_chebyshev_polynomial_t.n_scalar(Tensor x, Scalar n) -> Tensor |
| x: non_differentiable |
| |
| - name: special_chebyshev_polynomial_u(Tensor x, Tensor n) -> Tensor |
| x: non_differentiable |
| n: non_differentiable |
| |
| - name: special_chebyshev_polynomial_u.x_scalar(Scalar x, Tensor n) -> Tensor |
| n: non_differentiable |
| |
| - name: special_chebyshev_polynomial_u.n_scalar(Tensor x, Scalar n) -> Tensor |
| x: non_differentiable |
| |
| - name: special_chebyshev_polynomial_v(Tensor x, Tensor n) -> Tensor |
| x: non_differentiable |
| n: non_differentiable |
| |
| - name: special_chebyshev_polynomial_v.x_scalar(Scalar x, Tensor n) -> Tensor |
| n: non_differentiable |
| |
| - name: special_chebyshev_polynomial_v.n_scalar(Tensor x, Scalar n) -> Tensor |
| x: non_differentiable |
| |
| - name: special_chebyshev_polynomial_w(Tensor x, Tensor n) -> Tensor |
| x: non_differentiable |
| n: non_differentiable |
| |
| - name: special_chebyshev_polynomial_w.x_scalar(Scalar x, Tensor n) -> Tensor |
| n: non_differentiable |
| |
| - name: special_chebyshev_polynomial_w.n_scalar(Tensor x, Scalar n) -> Tensor |
| x: non_differentiable |
| |
| - name: special_hermite_polynomial_h(Tensor x, Tensor n) -> Tensor |
| x: non_differentiable |
| n: non_differentiable |
| |
| - name: special_hermite_polynomial_h.x_scalar(Scalar x, Tensor n) -> Tensor |
| n: non_differentiable |
| |
| - name: special_hermite_polynomial_h.n_scalar(Tensor x, Scalar n) -> Tensor |
| x: non_differentiable |
| |
| - name: special_hermite_polynomial_he(Tensor x, Tensor n) -> Tensor |
| x: non_differentiable |
| n: non_differentiable |
| |
| - name: special_hermite_polynomial_he.x_scalar(Scalar x, Tensor n) -> Tensor |
| n: non_differentiable |
| |
| - name: special_hermite_polynomial_he.n_scalar(Tensor x, Scalar n) -> Tensor |
| x: non_differentiable |
| |
| - name: special_laguerre_polynomial_l(Tensor x, Tensor n) -> Tensor |
| x: non_differentiable |
| n: non_differentiable |
| |
| - name: special_laguerre_polynomial_l.x_scalar(Scalar x, Tensor n) -> Tensor |
| n: non_differentiable |
| |
| - name: special_laguerre_polynomial_l.n_scalar(Tensor x, Scalar n) -> Tensor |
| x: non_differentiable |
| |
| - name: special_legendre_polynomial_p(Tensor x, Tensor n) -> Tensor |
| x: non_differentiable |
| n: non_differentiable |
| |
| - name: special_legendre_polynomial_p.x_scalar(Scalar x, Tensor n) -> Tensor |
| n: non_differentiable |
| |
| - name: special_legendre_polynomial_p.n_scalar(Tensor x, Scalar n) -> Tensor |
| x: non_differentiable |
| |
| - name: special_modified_bessel_i0(Tensor self) -> Tensor |
| self: non_differentiable |
| |
| - name: special_modified_bessel_i1(Tensor self) -> Tensor |
| self: non_differentiable |
| |
| - name: special_modified_bessel_k0(Tensor self) -> Tensor |
| self: non_differentiable |
| |
| - name: special_modified_bessel_k1(Tensor self) -> Tensor |
| self: non_differentiable |
| |
| - name: special_scaled_modified_bessel_k0(Tensor x) -> Tensor |
| x: non_differentiable |
| |
| - name: special_scaled_modified_bessel_k1(Tensor x) -> Tensor |
| x: non_differentiable |
| |
| - name: special_shifted_chebyshev_polynomial_t(Tensor x, Tensor n) -> Tensor |
| x: non_differentiable |
| n: non_differentiable |
| |
| - name: special_shifted_chebyshev_polynomial_t.x_scalar(Scalar x, Tensor n) -> Tensor |
| n: non_differentiable |
| |
| - name: special_shifted_chebyshev_polynomial_t.n_scalar(Tensor x, Scalar n) -> Tensor |
| x: non_differentiable |
| |
| - name: special_shifted_chebyshev_polynomial_u(Tensor x, Tensor n) -> Tensor |
| x: non_differentiable |
| n: non_differentiable |
| |
| - name: special_shifted_chebyshev_polynomial_u.x_scalar(Scalar x, Tensor n) -> Tensor |
| n: non_differentiable |
| |
| - name: special_shifted_chebyshev_polynomial_u.n_scalar(Tensor x, Scalar n) -> Tensor |
| x: non_differentiable |
| |
| - name: special_shifted_chebyshev_polynomial_v(Tensor x, Tensor n) -> Tensor |
| x: non_differentiable |
| n: non_differentiable |
| |
| - name: special_shifted_chebyshev_polynomial_v.x_scalar(Scalar x, Tensor n) -> Tensor |
| n: non_differentiable |
| |
| - name: special_shifted_chebyshev_polynomial_v.n_scalar(Tensor x, Scalar n) -> Tensor |
| x: non_differentiable |
| |
| - name: special_shifted_chebyshev_polynomial_w(Tensor x, Tensor n) -> Tensor |
| x: non_differentiable |
| n: non_differentiable |
| |
| - name: special_shifted_chebyshev_polynomial_w.x_scalar(Scalar x, Tensor n) -> Tensor |
| n: non_differentiable |
| |
| - name: special_shifted_chebyshev_polynomial_w.n_scalar(Tensor x, Scalar n) -> Tensor |
| x: non_differentiable |
| |
| - name: special_spherical_bessel_j0(Tensor x) -> Tensor |
| x: non_differentiable |
| |
| - name: _reshape_copy(Tensor self, SymInt[] size) -> Tensor |
| self: grad.reshape_symint(self.sym_sizes()) |
| result: auto_linear |