blob: 366b30c39eaa26b23e5403b45f6f9de9caeaef75 [file] [log] [blame]
# ATen native functions are a mechanism to write ATen methods which only
# make use of other ATen operations (e.g., it is not necessary to bind into
# TH/THC code). These functions are declared in this file and then folded
# into the ATen code generation process.
#
# The simple format is as follows:
# - func: func_name(ArgType arg0[=default], ArgType arg1[=default], ...) -> ReturnType
# ArgType(s) are allowed to be simple types understood by ATen
# (e.g. Tensor, TensorList, IntList, int64_t, double).
# ReturnType is allowed to be any ArgType or tuple combination of ArgType(s), e.g. (Tensor, Tensor)
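# defaults are optional; the declarations in this file use numbers (e.g. '0' for int64_t, '5.0' for double),
# bool literals ('true'/'false'), 'nullptr' (for Generator*) and '{}' (for Tensor and IntList arguments)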
#
# The declarations also support the following attributes:
# variants: [function, method] by default; controls whether Tensor method or namespace Function is generated
# as a result of this declaration.
# dispatch: equal to the func_name by default; this can be overridden by providing a
# backend-specific name to dispatch to, e.g.:
# CPU: func_cpu
# CUDA: func_cuda
# python_default_init: a map from argument names to default initialize expressions in C++. Such default
# expressions will only be used in the Python API. This allows us to write an argument with
# a default value that would either cause ambiguity in C++ (e.g., the `Scalar p` argument
# in `norm`) or have a type that doesn't allow a default value of None/NULL/nullptr (e.g., the
# `int64_t fft_size` argument in stft, which we want to default to the value of another
# argument if not provided).
#
# In addition to the variants generation, these declarations will also generate a C++ function declaration
# in NativeFunctions.h; it is up to you to write the corresponding definitions in the correct file under /native
# (the exact file depends on whether you are writing a generic (all-backend) or a CPU/CUDA-specific implementation).
# Note that these generated C++ function declarations won't match the declaration here because they will
# undergo the standard ATen C++ transformations, e.g. use of const-ref for non-inplace Tensor arguments. It is
# recommended that you copy the generated C++ function declaration to your definition so that the two
# match.
# 1-D adaptive pooling. Both entries set `variants: function`, so only the
# namespace function is generated (no Tensor method).
- func: adaptive_avg_pool1d(Tensor self, IntList[1] output_size) -> Tensor
variants: function
# Declared to return a pair of Tensors.
# NOTE(review): the second Tensor is presumably the argmax indices -- confirm
# against the implementation.
- func: adaptive_max_pool1d(Tensor self, IntList[1] output_size) -> (Tensor, Tensor)
variants: function
# No `variants` key: both the function and the Tensor method are generated by
# default. Returns a plain bool; rtol and atol default to 1e-5 and 1e-8.
- func: allclose(Tensor self, Tensor other, double rtol=1e-5, double atol=1e-8) -> bool
# Function-only. `weight` and `bias` are declared `Tensor?`; none of the
# arguments has a default value.
# NOTE(review): `Tensor?` presumably marks an argument that may be undefined --
# confirm against the code generator.
- func: batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, double momentum, double eps, bool cudnn_enabled) -> Tensor
variants: function
# Two overloads of bernoulli_: `p` is either a Tensor (no default) or a double
# defaulting to 0.5. Both take a Generator* that defaults to nullptr.
- func: bernoulli_(Tensor self, Tensor p, Generator* generator=nullptr) -> Tensor
- func: bernoulli_(Tensor self, double p=0.5, Generator* generator=nullptr) -> Tensor
# Returns a TensorList; `dim` defaults to 0.
- func: chunk(Tensor self, int64_t chunks, int64_t dim=0) -> TensorList
# Function-only predicate returning bool.
- func: cudnn_is_acceptable(Tensor self) -> bool
variants: function
# Generic convolution entry points; all are function-only and take `bias` as
# `Tensor?`. None of the arguments has a default value.
- func: convolution(Tensor input, Tensor weight, Tensor? bias, IntList stride, IntList padding, IntList dilation, bool transposed, IntList output_padding, int64_t groups) -> Tensor
variants: function
# Same arguments as `convolution`, followed by three extra bool flags:
# benchmark, deterministic and cudnn_enabled.
- func: _convolution(Tensor input, Tensor weight, Tensor? bias, IntList stride, IntList padding, IntList dilation, bool transposed, IntList output_padding, int64_t groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> Tensor
variants: function
# Same arguments as `convolution` minus `groups` (and without the three bool
# flags that `_convolution` adds).
- func: _convolution_nogroup(Tensor input, Tensor weight, Tensor? bias, IntList stride, IntList padding, IntList dilation, bool transposed, IntList output_padding) -> Tensor
variants: function
# NB: We MUST call the input self, otherwise codegen will attempt to
# dispatch on ggI... which might be undefined.
# ggI, ggW and ggb are declared `Tensor?`; the declaration returns three Tensors
# and takes a std::array<bool,3> output_mask.
# NOTE(review): output_mask presumably selects which of the three results are
# computed -- confirm against the implementation.
- func: _convolution_double_backward(Tensor? ggI, Tensor? ggW, Tensor? ggb, Tensor gO, Tensor weight, Tensor self, IntList stride, IntList padding, IntList dilation, bool transposed, IntList output_padding, int64_t groups, bool benchmark, bool deterministic, bool cudnn_enabled, std::array<bool,3> output_mask) -> (Tensor, Tensor, Tensor)
variants: function
# Dimension-specific convolution wrappers (function-only). The IntList[N]
# arguments use N = 1, 2, 3 to match the function's dimensionality and take
# scalar defaults (stride=1, padding=0, dilation=1); `bias` defaults to `{}`
# and `groups` to 1.
- func: conv1d(Tensor input, Tensor weight, Tensor bias={}, IntList[1] stride=1, IntList[1] padding=0, IntList[1] dilation=1, int64_t groups=1) -> Tensor
variants: function
- func: conv2d(Tensor input, Tensor weight, Tensor bias={}, IntList[2] stride=1, IntList[2] padding=0, IntList[2] dilation=1, int64_t groups=1) -> Tensor
variants: function
- func: conv3d(Tensor input, Tensor weight, Tensor bias={}, IntList[3] stride=1, IntList[3] padding=0, IntList[3] dilation=1, int64_t groups=1) -> Tensor
variants: function
# conv_tbc and its backward have no `variants` key, so both the function and
# the Tensor method are generated by default. `bias` is a required Tensor here
# and there are no default values. The backward is declared to return three
# Tensors.
- func: conv_tbc(Tensor self, Tensor weight, Tensor bias, int64_t pad) -> Tensor
- func: conv_tbc_backward(Tensor self, Tensor input, Tensor weight, Tensor bias, int64_t pad) -> (Tensor, Tensor, Tensor)
# NB: we inherit the goofy argument order from PyTorch torch.nn.functional
# Concretely: `output_padding` and `groups` come before `dilation` here,
# whereas conv1d/conv2d/conv3d place `dilation` before `groups`.
- func: conv_transpose1d(Tensor input, Tensor weight, Tensor bias={}, IntList[1] stride=1, IntList[1] padding=0, IntList[1] output_padding=0, int64_t groups=1, IntList[1] dilation=1) -> Tensor
variants: function
- func: conv_transpose2d(Tensor input, Tensor weight, Tensor bias={}, IntList[2] stride=1, IntList[2] padding=0, IntList[2] output_padding=0, int64_t groups=1, IntList[2] dilation=1) -> Tensor
variants: function
- func: conv_transpose3d(Tensor input, Tensor weight, Tensor bias={}, IntList[3] stride=1, IntList[3] padding=0, IntList[3] output_padding=0, int64_t groups=1, IntList[3] dilation=1) -> Tensor
variants: function
# In addition to the `-> Tensor` in the signature, this entry carries a
# `return:` list that names the result (`grid`).
# `dispatch` is a single name rather than a per-backend map, so the declaration
# dispatches to cudnn_affine_grid_generator_forward instead of the func name.
- func: cudnn_affine_grid_generator(Tensor theta, int64_t N, int64_t C, int64_t H, int64_t W) -> Tensor
return:
- type: Tensor
name: grid
variants: function
dispatch: cudnn_affine_grid_generator_forward
# TODO: Why do I have to call this grad?!
# This entry omits `-> ReturnType` from the signature; its single Tensor result
# is described only by the `return:` list, under the name `grad_theta`.
# There is no `dispatch` key, so it dispatches to the func name.
- func: cudnn_affine_grid_generator_backward(Tensor grad, int64_t N, int64_t C, int64_t H, int64_t W)
return:
- type: Tensor
name: grad_theta
variants: function
# Function-only. `bias` is declared `Tensor?` while `weight`, `running_mean`
# and `running_var` are plain `Tensor`. Declared to return three Tensors.
# NOTE(review): the second and third results are presumably the saved
# mean/variance consumed by cudnn_batch_norm_backward -- confirm.
- func: cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, double exponential_average_factor, double epsilon) -> (Tensor, Tensor, Tensor)
variants: function
# NB: You can only use this if you used cudnn_batch_norm training=True
# `save_mean` and `save_var` are declared `Tensor?`; returns three Tensors.
- func: cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor running_mean, Tensor running_var, Tensor? save_mean, Tensor? save_var, double epsilon) -> (Tensor, Tensor, Tensor)
variants: function
# cuDNN convolution entry points; all are function-only and have no defaults.
# Note the IntList order here is (padding, stride, dilation), which differs
# from the (stride, padding, dilation) order used by `convolution`.
- func: cudnn_convolution(Tensor self, Tensor weight, Tensor? bias, IntList padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic) -> Tensor
variants: function
# Takes the input's size (`self_size`) as an IntList instead of the input
# Tensor itself.
- func: cudnn_convolution_backward_input(IntList self_size, Tensor grad_output, Tensor weight, IntList padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic) -> Tensor
variants: function
# Combined backward: returns three Tensors and takes a std::array<bool,3>
# output_mask.
# NOTE(review): output_mask presumably selects which of the three gradients
# are computed -- confirm against the implementation.
- func: cudnn_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, IntList padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic, std::array<bool,3> output_mask) -> (Tensor, Tensor, Tensor)
variants: function
# Takes only grad_output.
- func: cudnn_convolution_backward_bias(Tensor grad_output) -> Tensor
variants: function
# Takes the weight's size (`weight_size`) as an IntList instead of the weight
# Tensor itself.
- func: cudnn_convolution_backward_weight(IntList weight_size, Tensor grad_output, Tensor self, IntList padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic) -> Tensor
variants: function
# Transposed variant: same as cudnn_convolution plus an `output_padding`
# IntList placed between `padding` and `stride`.
- func: cudnn_convolution_transpose(Tensor self, Tensor weight, Tensor? bias, IntList padding, IntList output_padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic) -> Tensor
variants: function
# NB: output_padding not strictly needed here, but it's helpful for the double
# backwards
- func: cudnn_convolution_transpose_backward(Tensor self, Tensor grad_output, Tensor weight, IntList padding, IntList output_padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic, std::array<bool,3> output_mask) -> (Tensor, Tensor, Tensor)
variants: function
# Dispatches to cudnn_convolution_backward_bias, i.e. the same target name as
# the non-transposed bias backward declared earlier in this file.
- func: cudnn_convolution_transpose_backward_bias(Tensor grad_output) -> Tensor
variants: function
dispatch: cudnn_convolution_backward_bias
# Unlike cudnn_convolution_backward_input, this takes no size argument; it
# also takes no output_padding.
- func: cudnn_convolution_transpose_backward_input(Tensor grad_output, Tensor weight, IntList padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic) -> Tensor
variants: function
# Same argument list as cudnn_convolution_backward_weight (no output_padding).
- func: cudnn_convolution_transpose_backward_weight(IntList weight_size, Tensor grad_output, Tensor self, IntList padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic) -> Tensor
variants: function
# NB: input is special cased in a way I don't quite understand
# The signature omits `-> ReturnType`; the single Tensor result is described by
# the `return:` list under the name `output`. `dispatch` is a single name, so
# the declaration dispatches to cudnn_grid_sampler_forward.
- func: cudnn_grid_sampler(Tensor self, Tensor grid)
return:
- type: Tensor
name: output
variants: function
dispatch: cudnn_grid_sampler_forward
# Backward: the `return:` list declares two named Tensor results, `grad_self`
# and `grad_grid`, in that order. There is no `dispatch` key, so it dispatches
# to the func name.
- func: cudnn_grid_sampler_backward(Tensor self, Tensor grid, Tensor grad_output)
return:
- type: Tensor
name: grad_self
- type: Tensor
name: grad_grid
variants: function
# Neither entry has a `variants` key, so both the function and the Tensor
# method are generated by default.
- func: det(Tensor self) -> Tensor
# Declared to return four Tensors.
# NOTE(review): presumably the determinant followed by the SVD factors --
# confirm the order against the implementation.
- func: _det_with_svd(Tensor self) -> (Tensor, Tensor, Tensor, Tensor)
# Embedding family; all entries are function-only and take `indices` as an
# IndexTensor. Only the forward has defaults (padding_idx=-1, both bool flags
# false).
- func: embedding(Tensor weight, IndexTensor indices, int64_t padding_idx=-1, bool scale_grad_by_freq=false, bool sparse=false) -> Tensor
variants: function
# Takes a `sparse` flag; the dense and sparse backward entries further down
# share this signature minus that flag.
# NOTE(review): presumably this forwards to embedding_sparse_backward or
# embedding_dense_backward depending on `sparse` -- confirm.
- func: embedding_backward(Tensor grad, IndexTensor indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor
variants: function
# Per-backend dispatch. Note that the target names are
# embedding_backward_cpu / embedding_backward_cuda (without "dense").
- func: embedding_dense_backward(Tensor grad, IndexTensor indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) -> Tensor
variants: function
dispatch:
CPU: embedding_backward_cpu
CUDA: embedding_backward_cuda
# Per-backend dispatch; the target names keep the trailing underscore.
- func: embedding_renorm_(Tensor self, IndexTensor indices, double max_norm, double norm_type) -> Tensor
variants: function
dispatch:
CPU: embedding_renorm_cpu_
CUDA: embedding_renorm_cuda_
# No `dispatch` key, so this dispatches to the func name on every backend.
- func: embedding_sparse_backward(Tensor grad, IndexTensor indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) -> Tensor
variants: function
# Function-only.
- func: empty_like(Tensor self) -> Tensor
variants: function
# expand / expand_as are the method-only entries in this run: only the Tensor
# method is generated, no namespace function.
- func: expand(Tensor self, IntList size) -> Tensor
variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too.
- func: expand_as(Tensor self, Tensor other) -> Tensor
variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too.
# Indexing by a TensorList of index tensors.
- func: index(Tensor self, TensorList indices) -> Tensor
# NB: This function is special-cased in tools/autograd/gen_variable_type.py
- func: index_put_(Tensor self, TensorList indices, Tensor values) -> Tensor
# Predicates returning a plain bool. None of them has a `variants` key, so
# both the function and the Tensor method are generated by default.
- func: is_cuda(Tensor self) -> bool
- func: is_distributed(Tensor self) -> bool
- func: is_floating_point(Tensor self) -> bool
- func: is_nonzero(Tensor self) -> bool
- func: is_same_size(Tensor self, Tensor other) -> bool
- func: is_signed(Tensor self) -> bool
- func: is_sparse(Tensor self) -> bool
# Default variants (function and method).
- func: matmul(Tensor self, Tensor other) -> Tensor
# Function-only; declared to return a pair of Tensors. `stride` defaults to
# `{}` while padding/dilation default to the scalars 0 and 1.
# NOTE(review): an empty `stride` presumably means "use kernel_size", and the
# second result is presumably the argmax indices -- confirm both against the
# implementation.
- func: max_pool1d(Tensor self, IntList[1] kernel_size, IntList[1] stride={}, IntList[1] padding=0, IntList[1] dilation=1, bool ceil_mode=false) -> (Tensor, Tensor)
variants: function
# Default variants (function and method); no default values.
- func: narrow(Tensor self, int64_t dim, int64_t start, int64_t length) -> Tensor
# TODO: Why does kW come before kH? Hella confusing, because it
# doesn't match the input layout.
# All NNPACK entries are function-only and pass width before height for both
# the kernel (kW, kH) and the padding (padW, padH).
- func: nnpack_spatial_convolution(Tensor input, Tensor weight, Tensor? bias, int64_t kW, int64_t kH, int64_t padW, int64_t padH) -> Tensor
variants: function
# Combined backward: returns three Tensors and takes a std::array<bool,3>
# output_mask.
# NOTE(review): output_mask presumably selects which gradients are computed --
# confirm against the implementation.
- func: nnpack_spatial_convolution_backward(Tensor input, Tensor grad_output, Tensor weight, int64_t kW, int64_t kH, int64_t padW, int64_t padH, std::array<bool,3> output_mask) -> (Tensor, Tensor, Tensor)
variants: function
- func: nnpack_spatial_convolution_backward_input(Tensor input, Tensor grad_output, Tensor weight, int64_t kW, int64_t kH, int64_t padW, int64_t padH) -> Tensor
variants: function
# Takes `weight_size` (IntList) in place of the weight Tensor, positioned
# between `input` and `grad_output`.
- func: nnpack_spatial_convolution_backward_weight(Tensor input, IntList weight_size, Tensor grad_output, int64_t kW, int64_t kH, int64_t padW, int64_t padH) -> Tensor
variants: function
# Method-only: only the Tensor method is generated, no namespace function.
- func: permute(Tensor self, IntList dims) -> Tensor
variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too.
# Default variants (function and method).
- func: pin_memory(Tensor self) -> Tensor
# Function-only.
- func: randn_like(Tensor self) -> Tensor
variants: function
# Default variants (function and method).
- func: repeat(Tensor self, IntList repeats) -> Tensor
# RoI pooling; function-only with per-backend dispatch to the *_cpu / *_cuda
# names. Declared to return a pair of Tensors.
# NOTE(review): the second result is presumably the `argmaxes` Tensor that
# RoiPooling2d_backward takes as its last argument -- confirm.
- func: RoiPooling2d_forward(Tensor input, Tensor rois, int64_t pooledHeight, int64_t pooledWidth, double spatialScale) -> (Tensor, Tensor)
variants: function
dispatch:
CPU: RoiPooling2d_forward_cpu
CUDA: RoiPooling2d_forward_cuda
# Repeats the forward's five arguments and then adds `gradOutput` and
# `argmaxes`.
- func: RoiPooling2d_backward(Tensor input, Tensor rois, int64_t pooledHeight, int64_t pooledWidth, double spatialScale, Tensor gradOutput, Tensor argmaxes) -> Tensor
variants: function
dispatch:
CPU: RoiPooling2d_backward_cpu
CUDA: RoiPooling2d_backward_cuda
# Activation functions; all four entries are function-only.
# rrelu / rrelu_ share one argument list: `lower` and `upper` are Scalars
# defaulting to 0.125 and 0.3333333333333333 (numerically 1/8 and 1/3),
# `training` defaults to false and `generator` to nullptr.
# NOTE(review): the trailing-underscore names presumably denote the in-place
# variants -- confirm against the implementations.
- func: rrelu(Tensor self, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=false, Generator* generator=nullptr) -> Tensor
variants: function
- func: rrelu_(Tensor self, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=false, Generator* generator=nullptr) -> Tensor
variants: function
# selu / selu_ take only the input Tensor.
- func: selu(Tensor self) -> Tensor
variants: function
- func: selu_(Tensor self) -> Tensor
variants: function
# Unless noted, the entries in this run have no `variants` key, so both the
# function and the Tensor method are generated by default.
# Returns a plain int64_t rather than a Tensor.
- func: size(Tensor self, int64_t dim) -> int64_t
# Every argument after `self` has a default. The default `end`,
# 9223372036854775807, equals 2^63 - 1 (the largest int64_t value).
- func: slice(Tensor self, int64_t dim=0, int64_t start=0, int64_t end=9223372036854775807, int64_t step=1) -> Tensor
# Returns a TensorList; `dim` defaults to 0.
- func: split(Tensor self, int64_t split_size, int64_t dim=0) -> TensorList
# squeeze and squeeze_ each have two overloads: without and with `dim`.
- func: squeeze(Tensor self) -> Tensor
- func: squeeze(Tensor self, int64_t dim) -> Tensor
- func: squeeze_(Tensor self) -> Tensor
- func: squeeze_(Tensor self, int64_t dim) -> Tensor
# Function-only; the first argument is a TensorList.
- func: stack(TensorList tensors, int64_t dim=0) -> Tensor
variants: function
# `fft_size` has no default in the signature itself; `python_default_init`
# supplies the expression `frame_length` as its default, which per the header
# of this file is used only in the Python API.
- func: stft(Tensor self, int64_t frame_length, int64_t hop, int64_t fft_size, bool return_onesided=true, Tensor window={}, int64_t pad_end=0) -> Tensor
python_default_init:
fft_size: frame_length
# Returns a plain int64_t rather than a Tensor.
- func: stride(Tensor self, int64_t dim) -> int64_t
- func: transpose_(Tensor self, int64_t dim0, int64_t dim1) -> Tensor
- func: t_(Tensor self) -> Tensor
- func: type_as(Tensor self, Tensor other) -> Tensor
- func: unsqueeze(Tensor self, int64_t dim) -> Tensor
- func: unsqueeze_(Tensor self, int64_t dim) -> Tensor
- func: view_as(Tensor self, Tensor other) -> Tensor
# We define both of these because 'where' does the broadcast and '_s_where' doesn't;
# this allows us to implicitly calculate the broadcast derivative, while only dealing with the
# _s_where derivative.
# Both take `condition` as a BoolTensor and have the same argument list.
# `where` has no `dispatch` key, so it dispatches to the func name.
- func: where(BoolTensor condition, Tensor self, Tensor other) -> Tensor
# `_s_where` dispatches per backend.
- func: _s_where(BoolTensor condition, Tensor self, Tensor other) -> Tensor
dispatch:
CPU: _s_where_cpu
CUDA: _s_where_cuda
# Per-backend dispatch; no `variants` key, so both the function and the Tensor
# method are generated by default.
- func: _standard_gamma_grad(Tensor self, Tensor output) -> Tensor
dispatch:
CPU: _standard_gamma_grad_cpu
CUDA: _standard_gamma_grad_cuda
# Function-only. The dispatch targets are named _s_poisson_cpu /
# _s_poisson_cuda, i.e. they do not follow the func name. `generator` defaults
# to nullptr.
- func: poisson(Tensor self, Generator* generator=nullptr) -> Tensor
variants: function
dispatch:
CPU: _s_poisson_cpu
CUDA: _s_poisson_cuda