| # See README.md in this directory for more guidance |
| |
| |
# Temporary type cast operators. These are needed to trace type casts for now,
# since Types are not supported in the IR. Instead, we call down to these
# specialized operators for each datatype.
| # TODO: remove when we have Type support in the IR |
| - func: _cast_Byte(Tensor self, bool non_blocking=false) -> Tensor |
| variants: function, method |
| |
| - func: _cast_Char(Tensor self, bool non_blocking=false) -> Tensor |
| variants: function, method |
| |
| - func: _cast_Double(Tensor self, bool non_blocking=false) -> Tensor |
| variants: function, method |
| |
| - func: _cast_Float(Tensor self, bool non_blocking=false) -> Tensor |
| variants: function, method |
| |
| - func: _cast_Int(Tensor self, bool non_blocking=false) -> Tensor |
| variants: function, method |
| |
| - func: _cast_Long(Tensor self, bool non_blocking=false) -> Tensor |
| variants: function, method |
| |
| - func: _cast_Short(Tensor self, bool non_blocking=false) -> Tensor |
| variants: function, method |
| |
| - func: _cast_Half(Tensor self, bool non_blocking=false) -> Tensor |
| variants: function, method |
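
# Hedged illustration (not part of the schema above): under tracing, a Python-level
# cast such as `x.float()` is assumed to lower to one of these specialized operators,
# roughly:
#   y = torch._cast_Float(x, non_blocking=False)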
| |
| - func: _cudnn_rnn_flatten_weight(TensorList weight_arr, int64_t weight_stride0, int64_t input_size, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, bool bidirectional) -> Tensor |
| variants: function |
| dispatch: |
| CUDA: _cudnn_rnn_flatten_weight |
| |
| - func: _cudnn_rnn(Tensor input, TensorList weight, int64_t weight_stride0, Tensor? weight_buf, Tensor hx, Tensor? cx, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, IntList batch_sizes, BoolTensor? dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor) |
| variants: function |
| dispatch: |
| CUDA: _cudnn_rnn |
| |
| - func: _cudnn_rnn_backward(Tensor input, TensorList weight, int64_t weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, IntList batch_sizes, BoolTensor? dropout_state, Tensor reserve, std::array<bool,4> output_mask) -> (Tensor, Tensor, Tensor, TensorList) |
| variants: function |
| dispatch: |
| CUDA: _cudnn_rnn_backward |
| |
| - func: _cudnn_init_dropout_state(Type self_ty, double dropout, bool train, int64_t dropout_seed) -> Tensor |
| variants: function |
| dispatch: |
| CUDA: _cudnn_init_dropout_state |
| |
| - func: abs(Tensor self) -> Tensor |
| |
| - func: abs_(Tensor self) -> Tensor |
| dispatch: |
| CPU: _abs__cpu |
| CUDA: _abs__cuda |
| |
| - func: abs_out(Tensor result, Tensor self) -> Tensor |
| variants: function |
| dispatch: |
| CPU: _abs_out_cpu |
| CUDA: _abs_out_cuda |
| |
| - func: acos(Tensor self) -> Tensor |
| |
| - func: acos_(Tensor self) -> Tensor |
| dispatch: |
| CPU: _acos__cpu |
| CUDA: _acos__cuda |
| |
| - func: acos_out(Tensor result, Tensor self) -> Tensor |
| variants: function |
| dispatch: |
| CPU: _acos_out_cpu |
| CUDA: _acos_out_cuda |
| |
| - func: avg_pool1d(Tensor self, IntList[1] kernel_size, IntList[1] stride={}, IntList[1] padding=0, bool ceil_mode=false, bool count_include_pad=true) -> Tensor |
| variants: function |
| |
| - func: adaptive_avg_pool1d(Tensor self, IntList[1] output_size) -> Tensor |
| variants: function |
| |
| - func: adaptive_max_pool1d(Tensor self, IntList[1] output_size) -> (Tensor, Tensor) |
| variants: function |
| |
- func: allclose(Tensor self, Tensor other, double rtol=1e-5, double atol=1e-8, bool equal_nan=false) -> bool
| device_guard: false |
| |
| - func: addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor |
| |
| - func: addmv_(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor |
| |
| - func: addmv_out(Tensor result, Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor |
| variants: function |
| |
| - func: addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor |
| |
| - func: addr_(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor |
| variants: method |
| |
| - func: addr_out(Tensor result, Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor |
| variants: function |
| |
| - func: all(Tensor self, int64_t dim, bool keepdim=false) -> Tensor |
| |
| - func: all_out(Tensor result, Tensor self, int64_t dim, bool keepdim=false) -> Tensor |
| variants: function |
| |
| - func: any(Tensor self, int64_t dim, bool keepdim=false) -> Tensor |
| |
| - func: any_out(Tensor result, Tensor self, int64_t dim, bool keepdim=false) -> Tensor |
| variants: function |
| |
| - func: arange(Scalar start, Scalar end, TensorOptions options={}) -> Tensor |
| variants: function |
| |
| - func: arange(Scalar start, Scalar end, Scalar step, TensorOptions options={}) -> Tensor |
| variants: function |
| |
| - func: arange_out(Tensor result, Scalar start, Scalar end) -> Tensor |
| variants: function |
| |
| - func: arange_out(Tensor result, Scalar start, Scalar end, Scalar step) -> Tensor |
| variants: function |
| |
| - func: arange(Scalar end, TensorOptions options={}) -> Tensor |
| variants: function |
| |
| - func: arange_out(Tensor result, Scalar end) -> Tensor |
| variants: function |
| |
| - func: arange(Type dtype, Scalar start, Scalar end, Scalar step=1) -> Tensor |
| variants: function |
| deprecated: true |
| |
| - func: arange(Type dtype, Scalar end) -> Tensor |
| variants: function |
| deprecated: true |
| |
# This function is a temporary hack to allow tracing of arange-like constructs with
# dynamic bounds on arange. Normal arange is not traceable because it does not take any
# tensor inputs; if the range you need is based on another tensor, calling this function
# directly will preserve tracing (see the sketch below). Get rid of this when arange can
# directly take tensors for bounds (so that it can be traced directly).
| - func: _dim_arange(Tensor like, int64_t dim) -> Tensor |
| variants: function |
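
# Hypothetical usage sketch (illustration only): instead of the untraceable
# `torch.arange(t.size(d))`, a traced model would call
#   idx = torch._dim_arange(t, d)
# so that the bound stays tied to the tensor `t` in the trace.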
| |
# `argmin` and `argmax` are exposed in C++ but not in Python, where we only
# expose `_argmin` and `_argmax` (which call the non-underscore versions). In
# Python, we then define our own `argmax` and `argmin` that handle passing
# `dim=None`, which returns the argmax/argmin of the flattened array (see the
# sketch after these declarations).
| |
| - func: argmax(Tensor self, int64_t dim, bool keepdim=false) -> Tensor |
| - func: argmax(Tensor self) -> Tensor |
| - func: _argmax(Tensor self, int64_t dim, bool keepdim=false) -> Tensor |
| |
| - func: argmin(Tensor self, int64_t dim, bool keepdim=false) -> Tensor |
| - func: argmin(Tensor self) -> Tensor |
| - func: _argmin(Tensor self, int64_t dim, bool keepdim=false) -> Tensor |
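
# A minimal sketch of the Python-side wrappers described above (hypothetical; the
# real definitions live in the Python bindings):
#   def argmax(input, dim=None, keepdim=False):
#       if dim is None:
#           return torch._argmax(input.contiguous().view(-1), dim=0, keepdim=False)
#       return torch._argmax(input, dim, keepdim)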
| |
| # The actual implementations live in Declarations.cwrap. These are just to |
| # provide default values for storage_offset=self.storage_offset() |
| - func: as_strided(Tensor self, IntList size, IntList stride) -> Tensor |
| - func: as_strided_(Tensor self, IntList size, IntList stride) -> Tensor |
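
# Hedged example of the defaulting described above: `x.as_strided(size, stride)`
# is intended to behave like
#   x.as_strided(size, stride, storage_offset=x.storage_offset())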
| |
| - func: asin(Tensor self) -> Tensor |
| |
| - func: asin_(Tensor self) -> Tensor |
| dispatch: |
| CPU: _asin__cpu |
| CUDA: _asin__cuda |
| |
| - func: asin_out(Tensor result, Tensor self) -> Tensor |
| variants: function |
| dispatch: |
| CPU: _asin_out_cpu |
| CUDA: _asin_out_cuda |
| |
| - func: atan(Tensor self) -> Tensor |
| |
| - func: atan_(Tensor self) -> Tensor |
| dispatch: |
| CPU: _atan__cpu |
| CUDA: _atan__cuda |
| |
| - func: atan_out(Tensor result, Tensor self) -> Tensor |
| variants: function |
| dispatch: |
| CPU: _atan_out_cpu |
| CUDA: _atan_out_cuda |
| |
| - func: bartlett_window(int64_t window_length, TensorOptions options={}) -> Tensor |
| variants: function |
| |
| - func: bartlett_window(int64_t window_length, bool periodic, TensorOptions options={}) -> Tensor |
| variants: function |
| |
| - func: batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, double momentum, double eps, bool cudnn_enabled) -> Tensor |
| variants: function |
| |
| - func: bernoulli(Tensor self, Tensor p, Generator* generator=nullptr) -> Tensor |
| |
| - func: bernoulli(Tensor self, double p, Generator* generator=nullptr) -> Tensor |
| |
| - func: bernoulli(Tensor self) -> Tensor |
| |
| - func: bernoulli_(Tensor self, Tensor p, Generator* generator=nullptr) -> Tensor |
| |
| - func: bernoulli_(Tensor self, double p, Generator* generator=nullptr) -> Tensor |
| |
| - func: bernoulli_(Tensor self) -> Tensor |
| |
| - func: bilinear(Tensor input1, Tensor input2, Tensor weight, Tensor? bias) -> Tensor |
| variants: function |
| |
| - func: bincount(Tensor self, Tensor? weights={}, int64_t minlength=0) -> Tensor |
| dispatch: |
| CPU: _bincount_cpu |
| CUDA: _bincount_cuda |
| |
| - func: blackman_window(int64_t window_length, TensorOptions options={}) -> Tensor |
| variants: function |
| |
| - func: blackman_window(int64_t window_length, bool periodic, TensorOptions options={}) -> Tensor |
| variants: function |
| |
| - func: cat(TensorList tensors, int64_t dim=0) -> Tensor |
| variants: function |
| |
| - func: cat_out(Tensor result, TensorList tensors, int64_t dim=0) -> Tensor |
| variants: function |
| |
| - func: ceil(Tensor self) -> Tensor |
| |
| - func: ceil_(Tensor self) -> Tensor |
| dispatch: |
| CPU: _ceil__cpu |
| CUDA: _ceil__cuda |
| |
| - func: ceil_out(Tensor result, Tensor self) -> Tensor |
| variants: function |
| dispatch: |
| CPU: _ceil_out_cpu |
| CUDA: _ceil_out_cuda |
| |
| - func: chunk(Tensor self, int64_t chunks, int64_t dim=0) -> TensorList |
| |
| - func: cudnn_is_acceptable(Tensor self) -> bool |
| variants: function |
| device_guard: false |
| |
| - func: convolution(Tensor input, Tensor weight, Tensor? bias, IntList stride, IntList padding, IntList dilation, bool transposed, IntList output_padding, int64_t groups) -> Tensor |
| variants: function |
| |
| - func: _convolution(Tensor input, Tensor weight, Tensor? bias, IntList stride, IntList padding, IntList dilation, bool transposed, IntList output_padding, int64_t groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> Tensor |
| variants: function |
| |
| - func: _convolution_nogroup(Tensor input, Tensor weight, Tensor? bias, IntList stride, IntList padding, IntList dilation, bool transposed, IntList output_padding) -> Tensor |
| variants: function |
| |
| # NB: We MUST call the input self, otherwise codegen will attempt to |
| # dispatch on ggI... which might be undefined. |
| - func: _convolution_double_backward(Tensor? ggI, Tensor? ggW, Tensor? ggb, Tensor gO, Tensor weight, Tensor self, IntList stride, IntList padding, IntList dilation, bool transposed, IntList output_padding, int64_t groups, bool benchmark, bool deterministic, bool cudnn_enabled, std::array<bool,3> output_mask) -> (Tensor, Tensor, Tensor) |
| variants: function |
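
# Hedged illustration of the hazard above (not the actual generated code): dispatching
# on ggI would emit something like
#   ggI.type()._convolution_double_backward(...)
# which breaks when ggI is an undefined tensor.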
| |
| - func: conv1d(Tensor input, Tensor weight, Tensor bias={}, IntList[1] stride=1, IntList[1] padding=0, IntList[1] dilation=1, int64_t groups=1) -> Tensor |
| variants: function |
| |
| - func: conv2d(Tensor input, Tensor weight, Tensor bias={}, IntList[2] stride=1, IntList[2] padding=0, IntList[2] dilation=1, int64_t groups=1) -> Tensor |
| variants: function |
| |
| - func: conv3d(Tensor input, Tensor weight, Tensor bias={}, IntList[3] stride=1, IntList[3] padding=0, IntList[3] dilation=1, int64_t groups=1) -> Tensor |
| variants: function |
| |
| - func: conv_tbc(Tensor self, Tensor weight, Tensor bias, int64_t pad) -> Tensor |
| |
| - func: conv_tbc_backward(Tensor self, Tensor input, Tensor weight, Tensor bias, int64_t pad) -> (Tensor, Tensor, Tensor) |
| |
| # NB: we inherit the goofy argument order from PyTorch torch.nn.functional |
| - func: conv_transpose1d(Tensor input, Tensor weight, Tensor bias={}, IntList[1] stride=1, IntList[1] padding=0, IntList[1] output_padding=0, int64_t groups=1, IntList[1] dilation=1) -> Tensor |
| variants: function |
| |
| - func: conv_transpose2d(Tensor input, Tensor weight, Tensor bias={}, IntList[2] stride=1, IntList[2] padding=0, IntList[2] output_padding=0, int64_t groups=1, IntList[2] dilation=1) -> Tensor |
| variants: function |
| |
| - func: conv_transpose3d(Tensor input, Tensor weight, Tensor bias={}, IntList[3] stride=1, IntList[3] padding=0, IntList[3] output_padding=0, int64_t groups=1, IntList[3] dilation=1) -> Tensor |
| variants: function |
| |
| - func: cos(Tensor self) -> Tensor |
| |
| - func: cos_(Tensor self) -> Tensor |
| dispatch: |
| CPU: _cos__cpu |
| CUDA: _cos__cuda |
| |
| - func: cos_out(Tensor result, Tensor self) -> Tensor |
| variants: function |
| dispatch: |
| CPU: _cos_out_cpu |
| CUDA: _cos_out_cuda |
| |
| - func: cosh(Tensor self) -> Tensor |
| |
| - func: cosh_(Tensor self) -> Tensor |
| dispatch: |
| CPU: _cosh__cpu |
| CUDA: _cosh__cuda |
| |
| - func: cosh_out(Tensor result, Tensor self) -> Tensor |
| variants: function |
| dispatch: |
| CPU: _cosh_out_cpu |
| CUDA: _cosh_out_cuda |
| |
| - func: cosine_embedding_loss(Tensor input1, Tensor input2, Tensor target, double margin=0.0, int64_t reduction=Reduction::ElementwiseMean) -> Tensor |
| variants: function |
| |
| - func: cudnn_affine_grid_generator(Tensor theta, int64_t N, int64_t C, int64_t H, int64_t W) -> Tensor |
| return: |
| - type: Tensor |
| name: grid |
| variants: function |
| dispatch: |
| CUDA: cudnn_affine_grid_generator_forward |
| |
| # TODO: Why do I have to call this grad?! |
| - func: cudnn_affine_grid_generator_backward(Tensor grad, int64_t N, int64_t C, int64_t H, int64_t W) |
| return: |
| - type: Tensor |
| name: grad_theta |
| variants: function |
| dispatch: |
| CUDA: cudnn_affine_grid_generator_backward |
| |
| - func: cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, double exponential_average_factor, double epsilon) -> (Tensor, Tensor, Tensor) |
| variants: function |
| dispatch: |
| CUDA: cudnn_batch_norm |
| |
# NB: You can only use this if you used cudnn_batch_norm with training=True
| - func: cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, double epsilon) -> (Tensor, Tensor, Tensor) |
| variants: function |
| dispatch: |
| CUDA: cudnn_batch_norm_backward |
| |
| - func: cudnn_convolution(Tensor self, Tensor weight, Tensor? bias, IntList padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic) -> Tensor |
| variants: function |
| dispatch: |
| CUDA: cudnn_convolution |
| |
| - func: cudnn_convolution_backward_input(IntList self_size, Tensor grad_output, Tensor weight, IntList padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic) -> Tensor |
| variants: function |
| dispatch: |
| CUDA: cudnn_convolution_backward_input |
| |
| - func: cudnn_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, IntList padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic, std::array<bool,3> output_mask) -> (Tensor, Tensor, Tensor) |
| variants: function |
| dispatch: |
| CUDA: cudnn_convolution_backward |
| |
| - func: cudnn_convolution_backward_bias(Tensor grad_output) -> Tensor |
| variants: function |
| dispatch: |
| CUDA: cudnn_convolution_backward_bias |
| |
| - func: cudnn_convolution_backward_weight(IntList weight_size, Tensor grad_output, Tensor self, IntList padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic) -> Tensor |
| variants: function |
| dispatch: |
| CUDA: cudnn_convolution_backward_weight |
| |
| - func: cudnn_convolution_transpose(Tensor self, Tensor weight, Tensor? bias, IntList padding, IntList output_padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic) -> Tensor |
| variants: function |
| dispatch: |
| CUDA: cudnn_convolution_transpose |
| |
| # NB: output_padding not strictly needed here, but it's helpful for the double |
| # backwards |
| - func: cudnn_convolution_transpose_backward(Tensor self, Tensor grad_output, Tensor weight, IntList padding, IntList output_padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic, std::array<bool,3> output_mask) -> (Tensor, Tensor, Tensor) |
| variants: function |
| dispatch: |
| CUDA: cudnn_convolution_transpose_backward |
| |
| - func: cudnn_convolution_transpose_backward_bias(Tensor grad_output) -> Tensor |
| variants: function |
| dispatch: |
| CUDA: cudnn_convolution_backward_bias |
| |
| - func: cudnn_convolution_transpose_backward_input(Tensor grad_output, Tensor weight, IntList padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic) -> Tensor |
| variants: function |
| dispatch: |
| CUDA: cudnn_convolution_transpose_backward_input |
| |
| - func: cudnn_convolution_transpose_backward_weight(IntList weight_size, Tensor grad_output, Tensor self, IntList padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic) -> Tensor |
| variants: function |
| dispatch: |
| CUDA: cudnn_convolution_transpose_backward_weight |
| |
| # NB: input is special cased in a way I don't quite understand |
| - func: cudnn_grid_sampler(Tensor self, Tensor grid) |
| return: |
| - type: Tensor |
| name: output |
| variants: function |
| dispatch: |
| CUDA: cudnn_grid_sampler_forward |
| |
| - func: cudnn_grid_sampler_backward(Tensor self, Tensor grid, Tensor grad_output) |
| return: |
| - type: Tensor |
| name: grad_self |
| - type: Tensor |
| name: grad_grid |
| variants: function |
| dispatch: |
| CUDA: cudnn_grid_sampler_backward |
| |
| # FIXME: These could be combined as optional<ScalarType> but for https://github.com/pytorch/pytorch/issues/6593. |
| - func: cumsum(Tensor self, int64_t dim, *, ScalarType dtype) -> Tensor |
| |
| - func: cumsum(Tensor self, int64_t dim) -> Tensor |
| |
| - func: cumsum_out(Tensor result, Tensor self, int64_t dim, *, ScalarType dtype) -> Tensor |
| variants: function |
| |
| - func: cumsum_out(Tensor result, Tensor self, int64_t dim) -> Tensor |
| variants: function |
| |
| # FIXME: These could be combined as optional<ScalarType> but for https://github.com/pytorch/pytorch/issues/6593. |
| - func: cumprod(Tensor self, int64_t dim, *, ScalarType dtype) -> Tensor |
| |
| - func: cumprod(Tensor self, int64_t dim) -> Tensor |
| |
| - func: cumprod_out(Tensor result, Tensor self, int64_t dim, *, ScalarType dtype) -> Tensor |
| variants: function |
| |
| - func: cumprod_out(Tensor result, Tensor self, int64_t dim) -> Tensor |
| variants: function |
| |
| - func: det(Tensor self) -> Tensor |
| |
| - func: diagflat(Tensor self, int64_t offset=0) -> Tensor |
| variants: function |
| |
| - func: diagonal(Tensor self, int64_t offset=0, int64_t dim1=0, int64_t dim2=1) -> Tensor |
| |
| - func: dot(Tensor self, Tensor tensor) -> Tensor |
| |
| - func: dot_out(Tensor result, Tensor self, Tensor tensor) -> Tensor |
| variants: function |
| |
| - func: einsum(std::string equation, TensorList tensors) -> Tensor |
| variants: function |
| |
| - func: embedding(Tensor weight, IndexTensor indices, int64_t padding_idx=-1, bool scale_grad_by_freq=false, bool sparse=false) -> Tensor |
| variants: function |
| |
| - func: embedding_backward(Tensor grad, IndexTensor indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor |
| variants: function |
| |
| - func: embedding_dense_backward(Tensor grad, IndexTensor indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) -> Tensor |
| variants: function |
| dispatch: |
| CPU: embedding_dense_backward_cpu |
| CUDA: embedding_dense_backward_cuda |
| |
| - func: embedding_renorm_(Tensor self, IndexTensor indices, double max_norm, double norm_type) -> Tensor |
| variants: function |
| dispatch: |
| CPU: embedding_renorm_cpu_ |
| CUDA: embedding_renorm_cuda_ |
| |
| - func: embedding_sparse_backward(Tensor grad, IndexTensor indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) -> Tensor |
| variants: function |
| |
# NOTE [ embedding_bag Native Functions ]
# The `_embedding_bag.*` variants assume that the input tensors other than `weight`,
# e.g. `indices` and `offsets` (and `offset2bag`), are contiguous.
# We really only need to enforce this for `_embedding_bag` (the forward) because
# the backward inputs are the same as the forward ones.
# The `embedding_bag` wrapper below exists to achieve this, e.g. by
# applying indices = indices.contiguous().
# The backward functions check that these input tensors are contiguous.
| |
| - func: embedding_bag(Tensor weight, IndexTensor indices, IndexTensor offsets, bool scale_grad_by_freq=false, int64_t mode=0, bool sparse=false) -> (Tensor, Tensor, Tensor, Tensor) |
| variants: function |
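
# A minimal sketch of the contiguity handling described in the note above
# (hypothetical, illustration only):
#   def embedding_bag(weight, indices, offsets, scale_grad_by_freq=False, mode=0, sparse=False):
#       return torch._embedding_bag(weight, indices.contiguous(), offsets.contiguous(),
#                                   scale_grad_by_freq, mode, sparse)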
| |
| - func: _embedding_bag(Tensor weight, IndexTensor indices, IndexTensor offsets, bool scale_grad_by_freq=false, int64_t mode=0, bool sparse=false) -> (Tensor, Tensor, Tensor, Tensor) |
| variants: function |
| dispatch: |
| CPU: _embedding_bag_cpu |
| CUDA: _embedding_bag_cuda |
| |
| - func: _embedding_bag_backward(Tensor grad, IndexTensor indices, IndexTensor offsets, IndexTensor offset2bag, IndexTensor bag_size, IndexTensor maximum_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse) -> Tensor |
| variants: function |
| |
| - func: _embedding_bag_sparse_backward(Tensor grad, IndexTensor indices, IndexTensor offsets, IndexTensor offset2bag, IndexTensor bag_size, int64_t num_weights, bool scale_grad_by_freq, int64_t mode) -> Tensor |
| variants: function |
| |
| - func: _embedding_bag_dense_backward(Tensor grad, IndexTensor indices, IndexTensor offsets, IndexTensor offset2bag, IndexTensor bag_size, IndexTensor maximum_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode) -> Tensor |
| variants: function |
| dispatch: |
| CPU: _embedding_bag_dense_backward_cpu |
| CUDA: _embedding_bag_dense_backward_cuda |
| |
| - func: empty(IntList size, TensorOptions options={}) -> Tensor |
| variants: function |
| |
| - func: empty_out(Tensor result, IntList size) -> Tensor |
| variants: function |
| |
| - func: empty_like(Tensor self) -> Tensor |
| variants: function |
| |
| - func: empty_like(Tensor self, *, TensorOptions options) -> Tensor |
| variants: function |
| |
| - func: empty(Type dtype, IntList size) -> Tensor |
| variants: function |
| deprecated: true |
| |
| - func: erf(Tensor self) -> Tensor |
| |
| - func: erf_(Tensor self) -> Tensor |
| dispatch: |
| CPU: _erf__cpu |
| CUDA: _erf__cuda |
| |
| - func: erf_out(Tensor result, Tensor self) -> Tensor |
| variants: function |
| dispatch: |
| CPU: _erf_out_cpu |
| CUDA: _erf_out_cuda |
| |
| - func: erfc(Tensor self) -> Tensor |
| |
| - func: erfc_(Tensor self) -> Tensor |
| dispatch: |
| CPU: _erfc__cpu |
| CUDA: _erfc__cuda |
| |
| - func: erfc_out(Tensor result, Tensor self) -> Tensor |
| variants: function |
| dispatch: |
| CPU: _erfc_out_cpu |
| CUDA: _erfc_out_cuda |
| |
| - func: exp(Tensor self) -> Tensor |
| |
| - func: exp_(Tensor self) -> Tensor |
| dispatch: |
| CPU: _exp__cpu |
| CUDA: _exp__cuda |
| |
| - func: exp_out(Tensor result, Tensor self) -> Tensor |
| variants: function |
| dispatch: |
| CPU: _exp_out_cpu |
| CUDA: _exp_out_cuda |
| |
| - func: expm1(Tensor self) -> Tensor |
| |
| - func: expm1_(Tensor self) -> Tensor |
| dispatch: |
| CPU: _expm1__cpu |
| CUDA: _expm1__cuda |
| |
| - func: expm1_out(Tensor result, Tensor self) -> Tensor |
| variants: function |
| dispatch: |
| CPU: _expm1_out_cpu |
| CUDA: _expm1_out_cuda |
| |
| - func: expand(Tensor self, IntList size, *, bool implicit=false) -> Tensor |
| variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too. |
| |
| - func: expand_as(Tensor self, Tensor other) -> Tensor |
| variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too. |
| |
| - func: eye(int64_t n, TensorOptions options={}) -> Tensor |
| variants: function |
| |
| - func: eye(int64_t n, int64_t m, TensorOptions options={}) -> Tensor |
| variants: function |
| |
| - func: eye_out(Tensor result, int64_t n) -> Tensor |
| variants: function |
| dispatch: |
| CPU: eye_out_cpu |
| CUDA: eye_out_cuda |
| |
| - func: eye_out(Tensor result, int64_t n, int64_t m) -> Tensor |
| variants: function |
| dispatch: |
| CPU: eye_out_cpu |
| CUDA: eye_out_cuda |
| |
| - func: eye(Type dtype, int64_t n, int64_t m=-1) -> Tensor |
| variants: function |
| deprecated: true |
| |
| - func: flatten(Tensor self, int64_t start_dim=0, int64_t end_dim=-1) -> Tensor |
| |
| - func: fill_(Tensor self, Scalar value) -> Tensor |
| |
| - func: fill_(Tensor self, Tensor value) -> Tensor |
| |
| - func: floor(Tensor self) -> Tensor |
| |
| - func: floor_(Tensor self) -> Tensor |
| dispatch: |
| CPU: _floor__cpu |
| CUDA: _floor__cuda |
| |
| - func: floor_out(Tensor result, Tensor self) -> Tensor |
| variants: function |
| dispatch: |
| CPU: _floor_out_cpu |
| CUDA: _floor_out_cuda |
| |
| - func: full(IntList size, Scalar fill_value, TensorOptions options={}) -> Tensor |
| variants: function |
| |
| - func: full_out(Tensor result, IntList size, Scalar fill_value) -> Tensor |
| variants: function |
| |
| - func: full_like(Tensor self, Scalar fill_value) -> Tensor |
| variants: function |
| |
| - func: full_like(Tensor self, Scalar fill_value, *, TensorOptions options) -> Tensor |
| variants: function |
| |
| - func: full(Type dtype, IntList size, Scalar fill_value) -> Tensor |
| variants: function |
| deprecated: true |
| |
| - func: grid_sampler(Tensor input, Tensor grid, int64_t padding_mode) -> Tensor |
| variants: function |
| |
| - func: hann_window(int64_t window_length, TensorOptions options={}) -> Tensor |
| variants: function |
| |
| - func: hann_window(int64_t window_length, bool periodic, TensorOptions options={}) -> Tensor |
| variants: function |
| |
| - func: hamming_window(int64_t window_length, TensorOptions options={}) -> Tensor |
| variants: function |
| |
| - func: hamming_window(int64_t window_length, bool periodic, TensorOptions options={}) -> Tensor |
| variants: function |
| |
| - func: hamming_window(int64_t window_length, bool periodic, double alpha, TensorOptions options={}) -> Tensor |
| variants: function |
| |
| - func: hamming_window(int64_t window_length, bool periodic, double alpha, double beta, TensorOptions options={}) -> Tensor |
| variants: function |
| |
| - func: hinge_embedding_loss(Tensor self, Tensor target, double margin=1.0, int64_t reduction=Reduction::ElementwiseMean) -> Tensor |
| variants: function |
| |
| - func: ger(Tensor self, Tensor vec2) -> Tensor |
| |
| - func: ger_out(Tensor result, Tensor self, Tensor vec2) -> Tensor |
| variants: function |
| |
| - func: gesv(Tensor self, Tensor A) -> (Tensor, Tensor) |
| |
| - func: gesv_out(Tensor solution, Tensor lu, Tensor self, Tensor A) -> (Tensor, Tensor) |
| variants: function |
| |
| # gesv handles broadcasting of arbitrary batch dims while _gesv_helper does not. |
| - func: _gesv_helper(Tensor self, Tensor A) -> (Tensor, Tensor) |
| dispatch: |
| CPU: _gesv_helper_cpu |
| CUDA: _gesv_helper_cuda |
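
# Hedged worked example of the split above: the broadcast of batch dimensions
# happens in gesv before the helper is called, e.g.
#   X, LU = torch.gesv(b, A)   # b: (5, 1), A: (2, 3, 5, 5)  ->  X: (2, 3, 5, 1)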
| |
- func: group_norm(Tensor input, int64_t num_groups, Tensor? weight={}, Tensor? bias={}, double eps=1e-5, bool cudnn_enabled=true) -> Tensor
| variants: function |
| |
| # FFT |
| |
| - func: fft(Tensor self, int64_t signal_ndim, bool normalized=false) -> Tensor |
| |
| - func: ifft(Tensor self, int64_t signal_ndim, bool normalized=false) -> Tensor |
| |
| - func: rfft(Tensor self, int64_t signal_ndim, bool normalized=false, bool onesided=true) -> Tensor |
| |
| - func: irfft(Tensor self, int64_t signal_ndim, bool normalized=false, bool onesided=true, IntList signal_sizes={}) -> Tensor |
| |
| - func: _fft_with_size(Tensor self, int64_t signal_ndim, bool complex_input, bool complex_output, bool inverse, IntList checked_signal_sizes, bool normalized, bool onesided, IntList output_sizes) -> Tensor |
| dispatch: |
| CPU: _fft_mkl |
| CUDA: _fft_cufft |
| |
| - func: _cufft_get_plan_cache_size() -> int64_t |
| variants: function |
| device_guard: false |
| |
| - func: _cufft_get_plan_cache_max_size() -> int64_t |
| variants: function |
| device_guard: false |
| |
| - func: _cufft_set_plan_cache_max_size(int64_t max_size) |
| variants: function |
| device_guard: false |
| |
| - func: _cufft_clear_plan_cache() |
| variants: function |
| device_guard: false |
| |
| - func: index(Tensor self, TensorList indices) -> Tensor |
| # NB: This function is special-cased in tools/autograd/gen_variable_type.py |
| |
| - func: index_copy_(Tensor self, int64_t dim, IndexTensor index, Tensor source) -> Tensor |
| variants: method |
| |
| - func: index_put(Tensor self, TensorList indices, Tensor values) -> Tensor |
| |
| - func: index_put_(Tensor self, TensorList indices, Tensor values) -> Tensor |
| |
- func: isclose(Tensor self, Tensor other, double rtol=1e-5, double atol=1e-8, bool equal_nan=false) -> Tensor
| |
| - func: is_cuda(Tensor self) -> bool |
| device_guard: false |
| |
| - func: is_distributed(Tensor self) -> bool |
| device_guard: false |
| |
| - func: is_floating_point(Tensor self) -> bool |
| device_guard: false |
| |
| - func: is_nonzero(Tensor self) -> bool |
| device_guard: false |
| |
| - func: is_same_size(Tensor self, Tensor other) -> bool |
| device_guard: false |
| |
| - func: is_signed(Tensor self) -> bool |
| device_guard: false |
| |
| - func: is_sparse(Tensor self) -> bool |
| device_guard: false |
| |
| - func: kthvalue(Tensor self, int64_t k, int64_t dim=-1, bool keepdim=false) -> (Tensor, Tensor) |
| |
| - func: kthvalue_out(Tensor values, Tensor indices, Tensor self, int64_t k, int64_t dim=-1, bool keepdim=false) -> (Tensor, Tensor) |
| variants: function |
| |
- func: layer_norm(Tensor input, IntList normalized_shape, Tensor? weight={}, Tensor? bias={}, double eps=1e-5, bool cudnn_enable=true) -> Tensor
| variants: function |
| |
| - func: linspace(Scalar start, Scalar end, TensorOptions options={}) -> Tensor |
| variants: function |
| |
| - func: linspace(Scalar start, Scalar end, int64_t steps, TensorOptions options={}) -> Tensor |
| variants: function |
| |
| - func: linspace_out(Tensor result, Scalar start, Scalar end) -> Tensor |
| variants: function |
| |
| - func: linspace_out(Tensor result, Scalar start, Scalar end, int64_t steps) -> Tensor |
| variants: function |
| |
| - func: linspace(Type dtype, Scalar start, Scalar end, int64_t steps=100) -> Tensor |
| variants: function |
| deprecated: true |
| |
| - func: log(Tensor self) -> Tensor |
| |
| - func: log_(Tensor self) -> Tensor |
| dispatch: |
| CPU: _log__cpu |
| CUDA: _log__cuda |
| |
| - func: log_out(Tensor result, Tensor self) -> Tensor |
| variants: function |
| dispatch: |
| CPU: _log_out_cpu |
| CUDA: _log_out_cuda |
| |
| - func: log10(Tensor self) -> Tensor |
| |
| - func: log10_(Tensor self) -> Tensor |
| dispatch: |
| CPU: _log10__cpu |
| CUDA: _log10__cuda |
| |
| - func: log10_out(Tensor result, Tensor self) -> Tensor |
| variants: function |
| dispatch: |
| CPU: _log10_out_cpu |
| CUDA: _log10_out_cuda |
| |
| - func: log1p(Tensor self) -> Tensor |
| |
| - func: log1p_(Tensor self) -> Tensor |
| dispatch: |
| CPU: _log1p__cpu |
| CUDA: _log1p__cuda |
| SparseCPU: log1p_sparse_ |
| SparseCUDA: log1p_sparse_ |
| |
| - func: log1p_out(Tensor result, Tensor self) -> Tensor |
| variants: function |
| dispatch: |
| CPU: _log1p_out_cpu |
| CUDA: _log1p_out_cuda |
| SparseCPU: log1p_out_sparse |
| SparseCUDA: log1p_out_sparse |
| |
| - func: log2(Tensor self) -> Tensor |
| |
| - func: log2_(Tensor self) -> Tensor |
| dispatch: |
| CPU: _log2__cpu |
| CUDA: _log2__cuda |
| |
| - func: log2_out(Tensor result, Tensor self) -> Tensor |
| variants: function |
| dispatch: |
| CPU: _log2_out_cpu |
| CUDA: _log2_out_cuda |
| |
| - func: logdet(Tensor self) -> Tensor |
| |
| - func: logspace(Scalar start, Scalar end, TensorOptions options={}) -> Tensor |
| variants: function |
| |
| - func: logspace(Scalar start, Scalar end, int64_t steps, TensorOptions options={}) -> Tensor |
| variants: function |
| |
| - func: logspace_out(Tensor result, Scalar start, Scalar end) -> Tensor |
| variants: function |
| |
| - func: logspace_out(Tensor result, Scalar start, Scalar end, int64_t steps) -> Tensor |
| variants: function |
| |
| - func: logspace(Type dtype, Scalar start, Scalar end, int64_t steps=100) -> Tensor |
| variants: function |
| deprecated: true |
| |
| - func: log_softmax(Tensor self, int64_t dim) -> Tensor |
| dispatch: |
| CPU: log_softmax_cpu |
| CUDA: log_softmax_cuda |
| |
| - func: log_softmax_backward_data(Tensor grad_output, Tensor output, int64_t dim, Tensor self) -> Tensor |
| dispatch: |
| CPU: log_softmax_backward_cpu |
| CUDA: log_softmax_backward_cuda |
| |
- func: logsumexp(Tensor self, int64_t dim, bool keepdim=false) -> Tensor

- func: logsumexp_out(Tensor result, Tensor self, int64_t dim, bool keepdim=false) -> Tensor
| variants: function |
| |
| - func: margin_ranking_loss(Tensor input1, Tensor input2, Tensor target, double margin=0.0, int64_t reduction=Reduction::ElementwiseMean) -> Tensor |
| variants: function |
| |
| - func: matmul(Tensor self, Tensor other) -> Tensor |
| |
| - func: matmul_out(Tensor result, Tensor self, Tensor other) -> Tensor |
| variants: function |
| |
| - func: max(Tensor self, int64_t dim, bool keepdim=false) -> (Tensor, Tensor) |
| |
| - func: max_out(Tensor max, Tensor max_values, Tensor self, int64_t dim, bool keepdim=false) -> (Tensor, Tensor) |
| variants: function |
| |
| - func: max_values(Tensor self, int64_t dim, bool keepdim=false) -> Tensor |
| |
| - func: max_pool1d_with_indices(Tensor self, IntList[1] kernel_size, IntList[1] stride={}, IntList[1] padding=0, IntList[1] dilation=1, bool ceil_mode=false) -> (Tensor, Tensor) |
| variants: function |
| |
| - func: max_pool1d(Tensor self, IntList[1] kernel_size, IntList[1] stride={}, IntList[1] padding=0, IntList[1] dilation=1, bool ceil_mode=false) -> Tensor |
| variants: function |
| |
- func: max_pool2d(Tensor self, IntList[2] kernel_size, IntList[2] stride={}, IntList[2] padding=0, IntList[2] dilation=1, bool ceil_mode=false) -> Tensor
| variants: function |
| |
- func: max_pool3d(Tensor self, IntList[3] kernel_size, IntList[3] stride={}, IntList[3] padding=0, IntList[3] dilation=1, bool ceil_mode=false) -> Tensor
| variants: function |
| |
| # FIXME: These could be combined as optional<ScalarType> but for https://github.com/pytorch/pytorch/issues/6593. |
| - func: mean(Tensor self, *, ScalarType dtype) -> Tensor |
| |
| - func: mean(Tensor self) -> Tensor |
| |
| - func: mean(Tensor self, int64_t dim, bool keepdim, *, ScalarType dtype) -> Tensor |
| |
- func: mean(Tensor self, int64_t dim, bool keepdim=false) -> Tensor
| |
| - func: mean(Tensor self, int64_t dim, *, ScalarType dtype) -> Tensor |
| |
| - func: mean_out(Tensor result, Tensor self, int64_t dim, bool keepdim, *, ScalarType dtype) -> Tensor |
| variants: function |
| |
- func: mean_out(Tensor result, Tensor self, int64_t dim, bool keepdim=false) -> Tensor
| variants: function |
| |
| - func: mean_out(Tensor result, Tensor self, int64_t dim, *, ScalarType dtype) -> Tensor |
| variants: function |
| |
| - func: median(Tensor self, int64_t dim, bool keepdim=false) -> (Tensor, Tensor) |
| |
| - func: median_out(Tensor values, Tensor indices, Tensor self, int64_t dim, bool keepdim=false) -> (Tensor, Tensor) |
| variants: function |
| |
| - func: min(Tensor self, int64_t dim, bool keepdim=false) -> (Tensor, Tensor) |
| |
| - func: min_out(Tensor min, Tensor min_indices, Tensor self, int64_t dim, bool keepdim=false) -> (Tensor, Tensor) |
| variants: function |
| |
| - func: min_values(Tensor self, int64_t dim, bool keepdim=false) -> Tensor |
| |
| - func: mkldnn_convolution(Tensor self, Tensor weight, Tensor? bias, IntList padding, IntList stride, IntList dilation) -> Tensor |
| variants: function |
| |
| - func: mkldnn_convolution_backward_input(IntList self_size, Tensor grad_output, Tensor weight, IntList padding, IntList stride, IntList dilation, bool bias_defined) -> Tensor |
| variants: function |
| |
| - func: mkldnn_convolution_backward_weights(IntList weight_size, Tensor grad_output, Tensor self, IntList padding, IntList stride, IntList dilation, bool bias_defined) -> (Tensor, Tensor) |
| variants: function |
| |
| - func: mkldnn_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, IntList padding, IntList stride, IntList dilation, std::array<bool,3> output_mask) -> (Tensor, Tensor, Tensor) |
| variants: function |
| |
| - func: mm(Tensor self, Tensor mat2) -> Tensor |
| |
| - func: mm_out(Tensor result, Tensor self, Tensor mat2) -> Tensor |
| variants: function |
| |
| - func: mode(Tensor self, int64_t dim=-1, bool keepdim=false) -> (Tensor, Tensor) |
| |
| - func: mode_out(Tensor values, Tensor indices, Tensor self, int64_t dim=-1, bool keepdim=false) -> (Tensor, Tensor) |
| variants: function |
| |
| - func: mv(Tensor self, Tensor vec) -> Tensor |
| |
| - func: mv_out(Tensor result, Tensor self, Tensor vec) -> Tensor |
| variants: function |
| |
| - func: narrow(Tensor self, int64_t dim, int64_t start, int64_t length) -> Tensor |
| |
| - func: ones(IntList size, TensorOptions options={}) -> Tensor |
| variants: function |
| |
| - func: ones_out(Tensor result, IntList size) -> Tensor |
| variants: function |
| |
| - func: ones_like(Tensor self) -> Tensor |
| variants: function |
| |
| - func: ones_like(Tensor self, *, TensorOptions options) -> Tensor |
| variants: function |
| |
| - func: ones(Type dtype, IntList size) -> Tensor |
| variants: function |
| deprecated: true |
| |
| - func: pairwise_distance(Tensor x1, Tensor x2, double p=2, double eps=1e-6, bool keepdim=false) -> Tensor |
| variants: function |
| |
| - func: permute(Tensor self, IntList dims) -> Tensor |
| variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too. |
| |
| - func: pin_memory(Tensor self) -> Tensor |
| |
| - func: pinverse(Tensor self, double rcond=1e-15) -> Tensor |
| |
| - func: rand(IntList size, *, TensorOptions options={}) -> Tensor |
| variants: function |
| |
| - func: rand(IntList size, *, Generator* generator, TensorOptions options={}) -> Tensor |
| variants: function |
| |
| - func: rand_out(Tensor result, IntList size, *) -> Tensor |
| variants: function |
| |
| - func: rand_out(Tensor result, IntList size, *, Generator* generator) -> Tensor |
| variants: function |
| |
| - func: rand_like(Tensor self) -> Tensor |
| variants: function |
| |
| - func: rand_like(Tensor self, *, TensorOptions options) -> Tensor |
| variants: function |
| |
| - func: rand(Type dtype, IntList size, *, Generator* generator=nullptr) -> Tensor |
| variants: function |
| deprecated: true |
| |
| - func: randint(int64_t high, IntList size, *, TensorOptions options={}) -> Tensor |
| variants: function |
| |
| - func: randint(int64_t high, IntList size, *, Generator* generator, TensorOptions options={}) -> Tensor |
| variants: function |
| |
| - func: randint(int64_t low, int64_t high, IntList size, *, TensorOptions options={}) -> Tensor |
| variants: function |
| |
| - func: randint(int64_t low, int64_t high, IntList size, *, Generator* generator, TensorOptions options={}) -> Tensor |
| variants: function |
| |
| - func: randint(Type dtype, int64_t high, IntList size, *, Generator* generator=nullptr) -> Tensor |
| variants: function |
| deprecated: true |
| |
| - func: randint(Type dtype, int64_t low, int64_t high, IntList size, *, Generator* generator=nullptr) -> Tensor |
| variants: function |
| deprecated: true |
| |
| - func: randint_out(Tensor result, int64_t high, IntList size, *) -> Tensor |
| variants: function |
| |
| - func: randint_out(Tensor result, int64_t high, IntList size, *, Generator* generator) -> Tensor |
| variants: function |
| |
| - func: randint_out(Tensor result, int64_t low, int64_t high, IntList size, *) -> Tensor |
| variants: function |
| |
| - func: randint_out(Tensor result, int64_t low, int64_t high, IntList size, *, Generator* generator) -> Tensor |
| variants: function |
| |
| - func: randint_like(Tensor self, int64_t high) -> Tensor |
| variants: function |
| |
| - func: randint_like(Tensor self, int64_t low, int64_t high) -> Tensor |
| variants: function |
| |
| - func: randint_like(Tensor self, int64_t high, *, TensorOptions options) -> Tensor |
| variants: function |
| |
| - func: randint_like(Tensor self, int64_t low, int64_t high, *, TensorOptions options) -> Tensor |
| variants: function |
| |
| - func: randn(IntList size, *, TensorOptions options={}) -> Tensor |
| variants: function |
| |
| - func: randn(IntList size, *, Generator* generator, TensorOptions options={}) -> Tensor |
| variants: function |
| |
| - func: randn_out(Tensor result, IntList size, *) -> Tensor |
| variants: function |
| |
| - func: randn_out(Tensor result, IntList size, *, Generator* generator) -> Tensor |
| variants: function |
| |
| - func: randn_like(Tensor self) -> Tensor |
| variants: function |
| |
| - func: randn_like(Tensor self, *, TensorOptions options) -> Tensor |
| variants: function |
| |
| - func: randn(Type dtype, IntList size, *, Generator* generator=nullptr) -> Tensor |
| variants: function |
| deprecated: true |
| |
| - func: randperm(int64_t n, *, TensorOptions options={}) -> Tensor |
| variants: function |
| |
| - func: randperm(int64_t n, *, Generator* generator, TensorOptions options={}) -> Tensor |
| variants: function |
| |
| - func: randperm_out(Tensor result, int64_t n, *) -> Tensor |
| variants: function |
| |
| - func: randperm_out(Tensor result, int64_t n, *, Generator* generator) -> Tensor |
| variants: function |
| dispatch: |
| CPU: randperm_out_cpu |
| CUDA: randperm_out_cuda |
| |
| - func: randperm(Type dtype, int64_t n, *, Generator* generator=nullptr) -> Tensor |
| variants: function |
| deprecated: true |
| |
| - func: range(Scalar start, Scalar end, TensorOptions options={}) -> Tensor |
| variants: function |
| |
| - func: range(Scalar start, Scalar end, Scalar step, TensorOptions options={}) -> Tensor |
| variants: function |
| |
| - func: range_out(Tensor result, Scalar start, Scalar end) -> Tensor |
| variants: function |
| |
| - func: range_out(Tensor result, Scalar start, Scalar end, Scalar step) -> Tensor |
| variants: function |
| |
| - func: range(Type dtype, Scalar start, Scalar end, Scalar step=1) -> Tensor |
| variants: function |
| deprecated: true |
| |
| - func: repeat(Tensor self, IntList repeats) -> Tensor |
| variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too. |
| |
| - func: reshape(Tensor self, IntList shape) -> Tensor |
| |
| - func: RoiPooling2d_forward(Tensor input, Tensor rois, int64_t pooledHeight, int64_t pooledWidth, double spatialScale) -> (Tensor, Tensor) |
| variants: function |
| dispatch: |
| CPU: RoiPooling2d_forward_cpu |
| CUDA: RoiPooling2d_forward_cuda |
| |
| - func: RoiPooling2d_backward(Tensor input, Tensor rois, int64_t pooledHeight, int64_t pooledWidth, double spatialScale, Tensor gradOutput, Tensor argmaxes) -> Tensor |
| variants: function |
| dispatch: |
| CPU: RoiPooling2d_backward_cpu |
| CUDA: RoiPooling2d_backward_cuda |
| |
| - func: round(Tensor self) -> Tensor |
| |
| - func: round_(Tensor self) -> Tensor |
| dispatch: |
| CPU: _round__cpu |
| CUDA: _round__cuda |
| |
| - func: round_out(Tensor result, Tensor self) -> Tensor |
| variants: function |
| dispatch: |
| CPU: _round_out_cpu |
| CUDA: _round_out_cuda |
| |
| - func: rrelu(Tensor self, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=false, Generator* generator=nullptr) -> Tensor |
| variants: function |
| |
| - func: rrelu_(Tensor self, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=false, Generator* generator=nullptr) -> Tensor |
| variants: function |
| |
| - func: relu(Tensor self) -> Tensor |
| |
| - func: relu_(Tensor self) -> Tensor |
| |
| - func: hardshrink(Tensor self, Scalar lambd=0.5) -> Tensor |
| dispatch: |
| CPU: hardshrink_cpu |
| CUDA: hardshrink_cuda |
| |
| - func: hardshrink_backward(Tensor grad_out, Tensor self, Scalar lambd) -> Tensor |
| dispatch: |
| CPU: hardshrink_backward_cpu |
| CUDA: hardshrink_backward_cuda |
| |
| - func: rsqrt(Tensor self) -> Tensor |
| |
| - func: rsqrt_(Tensor self) -> Tensor |
| dispatch: |
| CPU: _rsqrt__cpu |
| CUDA: _rsqrt__cuda |
| |
| - func: rsqrt_out(Tensor result, Tensor self) -> Tensor |
| variants: function |
| dispatch: |
| CPU: _rsqrt_out_cpu |
| CUDA: _rsqrt_out_cuda |
| |
| - func: select(Tensor self, int64_t dim, int64_t index) -> Tensor |
| |
| - func: selu(Tensor self) -> Tensor |
| variants: function |
| |
| - func: selu_(Tensor self) -> Tensor |
| variants: function |
| |
| - func: sigmoid(Tensor self) -> Tensor |
| |
| - func: sigmoid_(Tensor self) -> Tensor |
| dispatch: |
| CPU: _sigmoid__cpu |
| CUDA: _sigmoid__cuda |
| |
| - func: sigmoid_out(Tensor result, Tensor self) -> Tensor |
| variants: function |
| dispatch: |
| CPU: _sigmoid_out_cpu |
| CUDA: _sigmoid_out_cuda |
| |
| - func: sin(Tensor self) -> Tensor |
| |
| - func: sin_(Tensor self) -> Tensor |
| dispatch: |
| CPU: _sin__cpu |
| CUDA: _sin__cuda |
| |
| - func: sin_out(Tensor result, Tensor self) -> Tensor |
| variants: function |
| dispatch: |
| CPU: _sin_out_cpu |
| CUDA: _sin_out_cuda |
| |
| - func: sinh(Tensor self) -> Tensor |
| |
| - func: sinh_(Tensor self) -> Tensor |
| dispatch: |
| CPU: _sinh__cpu |
| CUDA: _sinh__cuda |
| |
| - func: sinh_out(Tensor result, Tensor self) -> Tensor |
| variants: function |
| dispatch: |
| CPU: _sinh_out_cpu |
| CUDA: _sinh_out_cuda |
| |
| - func: size(Tensor self, int64_t dim) -> int64_t |
| device_guard: false |
| |
| - func: slice(Tensor self, int64_t dim=0, int64_t start=0, int64_t end=9223372036854775807, int64_t step=1) -> Tensor |
| |
| - func: slogdet(Tensor self) -> (Tensor, Tensor) |
| |
| - func: smm(Tensor self, Tensor mat2) -> Tensor |
| |
| - func: softmax(Tensor self, int64_t dim) -> Tensor |
| dispatch: |
| CPU: softmax_cpu |
| CUDA: softmax_cuda |
| |
| - func: softmax_backward_data(Tensor grad_output, Tensor output, int64_t dim, Tensor self) -> Tensor |
| dispatch: |
| CPU: softmax_backward_cpu |
| CUDA: softmax_backward_cuda |
| |
| - func: split(Tensor self, int64_t split_size, int64_t dim=0) -> TensorList |
| |
| - func: split_with_sizes(Tensor self, IntList split_sizes, int64_t dim=0) -> TensorList |
| |
| - func: squeeze(Tensor self) -> Tensor |
| |
| - func: squeeze(Tensor self, int64_t dim) -> Tensor |
| |
| - func: squeeze_(Tensor self) -> Tensor |
| variants: method |
| |
| - func: squeeze_(Tensor self, int64_t dim) -> Tensor |
| variants: method |
| |
| - func: sspaddmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor |
| |
| - func: sspaddmm_out(Tensor result, Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor |
| variants: function |
| dispatch: |
| CPU: _sspaddmm_out_only_sparse |
| CUDA: _sspaddmm_out_only_sparse_cuda |
| SparseCPU: _sspaddmm_out_cpu |
| SparseCUDA: _sspaddmm_out_cuda |
| |
| - func: stack(TensorList tensors, int64_t dim=0) -> Tensor |
| variants: function |
| |
| - func: stack_out(Tensor result, TensorList tensors, int64_t dim=0) -> Tensor |
| variants: function |
| |
| - func: stft(Tensor self, int64_t frame_length, int64_t hop, int64_t fft_size, bool normalized=false, bool onesided=true, Tensor? window={}, int64_t pad_end=0) -> Tensor |
| python_default_init: |
| fft_size: frame_length |
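
# Hedged reading of the python_default_init above: at the Python level,
#   torch.stft(x, 400, 160)
# is assumed to behave like torch.stft(x, 400, 160, fft_size=400).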
| |
| - func: stride(Tensor self, int64_t dim) -> int64_t |
| device_guard: false |
| |
| # FIXME: These could be combined as optional<ScalarType> but for https://github.com/pytorch/pytorch/issues/6593. |
| - func: sum(Tensor self, *, ScalarType dtype) -> Tensor |
| |
| - func: sum(Tensor self) -> Tensor |
| |
| - func: _sum(Tensor self) -> Tensor |
| dispatch: |
| CPU: _sum_cpu |
| CUDA: _sum_cuda |
| |
| - func: sum(Tensor self, IntList[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor |
| |
- func: sum(Tensor self, IntList[1] dim, bool keepdim=false) -> Tensor
| |
| - func: sum(Tensor self, IntList[1] dim, *, ScalarType dtype) -> Tensor |
| |
- func: _sum(Tensor self, IntList[1] dim, bool keepdim=false) -> Tensor
| |
| - func: sum_out(Tensor result, Tensor self, IntList[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor |
| variants: function |
| |
- func: sum_out(Tensor result, Tensor self, IntList[1] dim, bool keepdim=false) -> Tensor
| variants: function |
| |
| - func: sum_out(Tensor result, Tensor self, IntList[1] dim, *, ScalarType dtype) -> Tensor |
| variants: function |
| |
- func: _sum_out(Tensor result, Tensor self, IntList[1] dim, bool keepdim=false) -> Tensor
| variants: function |
| |
- func: _sum_cuda_out(Tensor result, Tensor self, int64_t dim, bool keepdim=false) -> Tensor
| variants: function |
| dispatch: |
| CUDA: _sum_out_cuda |
| |
| - func: sqrt(Tensor self) -> Tensor |
| |
| - func: sqrt_(Tensor self) -> Tensor |
| dispatch: |
| CPU: _sqrt__cpu |
| CUDA: _sqrt__cuda |
| |
| - func: sqrt_out(Tensor result, Tensor self) -> Tensor |
| variants: function |
| dispatch: |
| CPU: _sqrt_out_cpu |
| CUDA: _sqrt_out_cuda |
| |
| - func: std(Tensor self, bool unbiased=true) -> Tensor |
| |
| - func: std(Tensor self, int64_t dim, bool unbiased=true, bool keepdim=false) -> Tensor |
| |
| - func: std_out(Tensor result, Tensor self, int64_t dim, bool unbiased=true, bool keepdim=false) -> Tensor |
| variants: function |
| |
| # FIXME: These could be combined as optional<ScalarType> but for https://github.com/pytorch/pytorch/issues/6593. |
| - func: prod(Tensor self, *, ScalarType dtype) -> Tensor |
| |
| - func: prod(Tensor self) -> Tensor |
| |
| - func: _prod(Tensor self) -> Tensor |
| dispatch: |
| CPU: _prod_cpu |
| CUDA: _prod_cuda |
| |
| - func: prod(Tensor self, int64_t dim, bool keepdim, *, ScalarType dtype) -> Tensor |
| |
- func: prod(Tensor self, int64_t dim, bool keepdim=false) -> Tensor
| |
| - func: prod(Tensor self, int64_t dim, *, ScalarType dtype) -> Tensor |
| |
- func: _prod(Tensor self, int64_t dim, bool keepdim=false) -> Tensor
| |
| - func: prod_out(Tensor result, Tensor self, int64_t dim, bool keepdim, *, ScalarType dtype) -> Tensor |
| variants: function |
| |
- func: prod_out(Tensor result, Tensor self, int64_t dim, bool keepdim=false) -> Tensor
| variants: function |
| |
| - func: prod_out(Tensor result, Tensor self, int64_t dim, *, ScalarType dtype) -> Tensor |
| variants: function |
| |
- func: _prod_out(Tensor result, Tensor self, int64_t dim, bool keepdim=false) -> Tensor
| variants: function |
| dispatch: |
| CPU: _prod_out_cpu |
| CUDA: _prod_out_cuda |
| |
| - func: t(Tensor self) -> Tensor |
| |
| - func: t_(Tensor self) -> Tensor |
| variants: method |
| |
| - func: tan(Tensor self) -> Tensor |
| |
| - func: tan_(Tensor self) -> Tensor |
| dispatch: |
| CPU: _tan__cpu |
| CUDA: _tan__cuda |
| |
| - func: tan_out(Tensor result, Tensor self) -> Tensor |
| variants: function |
| dispatch: |
| CPU: _tan_out_cpu |
| CUDA: _tan_out_cuda |
| |
| - func: tanh(Tensor self) -> Tensor |
| |
| - func: tanh_(Tensor self) -> Tensor |
| dispatch: |
| CPU: _tanh__cpu |
| CUDA: _tanh__cuda |
| |
| - func: tanh_out(Tensor result, Tensor self) -> Tensor |
| variants: function |
| dispatch: |
| CPU: _tanh_out_cpu |
| CUDA: _tanh_out_cuda |
| |
| - func: transpose(Tensor self, int64_t dim0, int64_t dim1) -> Tensor |
| |
| - func: transpose_(Tensor self, int64_t dim0, int64_t dim1) -> Tensor |
| variants: method |
| |
| - func: flip(Tensor self, IntList dims) -> Tensor |
| dispatch: |
| CPU: flip_cpu |
| CUDA: flip_cuda |
| |
| - func: _trilinear(Tensor i1, Tensor i2, Tensor i3, IntList expand1, IntList expand2, IntList expand3, IntList sumdim, int64_t unroll_dim=1) -> Tensor |
| variants: function |
| |
| - func: triplet_margin_loss(Tensor anchor, Tensor positive, Tensor negative, double margin=1.0, double p=2, double eps=1e-6, bool swap=false, int64_t reduction=Reduction::ElementwiseMean) -> Tensor |
| variants: function |
| |
| - func: trunc(Tensor self) -> Tensor |
| |
| - func: trunc_(Tensor self) -> Tensor |
| dispatch: |
| CPU: _trunc__cpu |
| CUDA: _trunc__cuda |
| |
| - func: trunc_out(Tensor result, Tensor self) -> Tensor |
| variants: function |
| dispatch: |
| CPU: _trunc_out_cpu |
| CUDA: _trunc_out_cuda |
| |
| - func: type_as(Tensor self, Tensor other) -> Tensor |
| variants: method |
| |
| - func: _unique(Tensor self, bool sorted=false, bool return_inverse=false) -> (Tensor, Tensor) |
| dispatch: |
| CPU: _unique_cpu |
| CUDA: _unique_cuda |
| |
| - func: _unsafe_view(Tensor self, IntList size) -> Tensor |
| variants: function |
| |
| - func: unsqueeze(Tensor self, int64_t dim) -> Tensor |
| |
| - func: unsqueeze_(Tensor self, int64_t dim) -> Tensor |
| variants: method |
| |
| - func: var(Tensor self, bool unbiased=true) -> Tensor |
| |
| - func: var(Tensor self, int64_t dim, bool unbiased=true, bool keepdim=false) -> Tensor |
| |
| - func: var_out(Tensor result, Tensor self, int64_t dim, bool unbiased=true, bool keepdim=false) -> Tensor |
| variants: function |
| |
| - func: view_as(Tensor self, Tensor other) -> Tensor |
| variants: method |
| |
| # we define both of these because 'where' does the broadcast and '_s_where' doesn't; |
| # this allows us to implicitly calculate the broadcast derivative, while only dealing with the |
| # _s_where derivative. |
| - func: where(BoolTensor condition, Tensor self, Tensor other) -> Tensor |
| - func: _s_where(BoolTensor condition, Tensor self, Tensor other) -> Tensor |
| dispatch: |
| CPU: _s_where_cpu |
| CUDA: _s_where_cuda |
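
# Rough sketch of the split described above (hypothetical; `broadcast_all` is a
# stand-in for whatever broadcasting helper is actually used):
#   def where(condition, self, other):
#       condition, self, other = broadcast_all(condition, self, other)
#       return torch._s_where(condition, self, other)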
| |
| - func: zeros(IntList size, TensorOptions options={}) -> Tensor |
| variants: function |
| |
| - func: zeros_out(Tensor result, IntList size) -> Tensor |
| variants: function |
| |
| - func: zeros_like(Tensor self) -> Tensor |
| variants: function |
| |
| - func: zeros_like(Tensor self, *, TensorOptions options) -> Tensor |
| variants: function |
| |
| - func: zeros(Type dtype, IntList size) -> Tensor |
| variants: function |
| deprecated: true |
| |
| - func: _standard_gamma_grad(Tensor self, Tensor output) -> Tensor |
| dispatch: |
| CPU: _standard_gamma_grad_cpu |
| CUDA: _standard_gamma_grad_cuda |
| |
| - func: _standard_gamma(Tensor self, Generator* generator=nullptr) -> Tensor |
| dispatch: |
| CPU: _s_gamma_cpu |
| CUDA: _s_gamma_cuda |
| |
| - func: poisson(Tensor self, Generator* generator=nullptr) -> Tensor |
| variants: function |
| dispatch: |
| CPU: _s_poisson_cpu |
| CUDA: _s_poisson_cuda |
| |
| # When more variants get ported to native, this dispatch will get more |
| # complicated |
| |
| - func: native_norm(Tensor self, Scalar p=2) -> Tensor |
| variants: function |
| dispatch: |
| SparseCPU: norm_sparse |
| SparseCUDA: norm_sparse |
| |
| - func: norm(Tensor self, Scalar p=2) -> Tensor |
| variants: method, function |
| |
| - func: norm(Tensor self, Scalar p, int64_t dim, bool keepdim=false) -> Tensor |
| python_default_init: |
| p: 2 |
| |
| - func: norm_out(Tensor result, Tensor self, Scalar p, int64_t dim, bool keepdim=false) -> Tensor |
| variants: function |
| python_default_init: |
| p: 2 |
| |
| - func: native_clone(Tensor self) -> Tensor |
| variants: function |
| dispatch: |
| SparseCPU: clone_sparse |
| SparseCUDA: clone_sparse |
| |
| - func: clone(Tensor self) -> Tensor |
| |
| - func: native_resize_as_(Tensor self, Tensor the_template) -> Tensor |
| variants: function |
| dispatch: |
| SparseCPU: resize_as_sparse_ |
| SparseCUDA: resize_as_sparse_ |
| |
| - func: resize_as_(Tensor self, Tensor the_template) -> Tensor |
| |
| - func: native_pow_out(Tensor result, Tensor self, Scalar exponent) -> Tensor |
| variants: function |
| dispatch: |
| SparseCPU: pow_out_sparse_scalar |
| SparseCUDA: pow_out_sparse_scalar |
| |
| - func: native_pow(Tensor self, Scalar exponent) -> Tensor |
| variants: function |
| dispatch: |
| SparseCPU: pow_sparse_scalar |
| SparseCUDA: pow_sparse_scalar |
| |
| - func: pow_out(Tensor result, Tensor self, Scalar exponent) -> Tensor |
| variants: function |
| |
| - func: pow(Tensor self, Scalar exponent) -> Tensor |
| variants: method, function |
| |
| - func: native_zero_(Tensor self) -> Tensor |
| variants: function |
| dispatch: |
| SparseCPU: zero_sparse_ |
| SparseCUDA: zero_sparse_ |
| |
| - func: zero_(Tensor self) -> Tensor |
| |
| - func: s_native_add_out(Tensor result, Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor |
| variants: function |
| dispatch: |
| SparseCPU: s_add_out_sparse_cpu |
| SparseCUDA: s_add_out_sparse_cuda |
| |
| - func: native_add_out(Tensor result, Tensor self, SparseTensorRef other, *, Scalar alpha=1) -> Tensor |
| variants: function |
| dispatch: |
| CPU: add_out_dense_sparse_cpu |
| CUDA: add_out_dense_sparse_cuda |
| |
| - func: s_native_add(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor |
| variants: function |
| dispatch: |
| SparseCPU: s_add_sparse_cpu |
| SparseCUDA: s_add_sparse_cuda |
| |
| - func: native_add(Tensor self, SparseTensorRef other, *, Scalar alpha=1) -> Tensor |
| variants: function |
| dispatch: |
| CPU: add_dense_sparse_cpu |
| CUDA: add_dense_sparse_cuda |
| |
| - func: s_native_add_(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor |
| variants: function |
| dispatch: |
| SparseCPU: s_add_sparse_cpu_ |
| SparseCUDA: s_add_sparse_cuda_ |
| |
| - func: native_add_(Tensor self, SparseTensorRef other, *, Scalar alpha=1) -> Tensor |
| variants: function |
| dispatch: |
| CPU: add_dense_sparse_cpu_ |
| CUDA: add_dense_sparse_cuda_ |
| |
| - func: add_out(Tensor result, Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor |
| variants: function |
| |
| - func: add(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor |
| variants: method, function |
| |
| - func: add_(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor |
| variants: method |
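
# Hedged sketch of the alpha argument (assuming the generated torch.add binding
# exposes it as a keyword): add computes self + alpha * other.
#
#   x = torch.ones(3)
#   y = torch.ones(3) * 2
#   torch.add(x, y, alpha=3)     # == x + 3 * y -> tensor of 7s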
| |
| |
| |
| - func: s_native_sub_out(Tensor result, Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor |
| variants: function |
| dispatch: |
| SparseCPU: s_sub_out_sparse_cpu |
| SparseCUDA: s_sub_out_sparse_cuda |
| |
| - func: s_native_sub(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor |
| variants: function |
| dispatch: |
| SparseCPU: s_sub_sparse_cpu |
| SparseCUDA: s_sub_sparse_cuda |
| |
| - func: s_native_sub_(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor |
| variants: function |
| dispatch: |
| SparseCPU: s_sub_sparse_cpu_ |
| SparseCUDA: s_sub_sparse_cuda_ |
| |
| - func: sub_out(Tensor result, Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor |
| variants: function |
| |
| - func: sub(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor |
| variants: method, function |
| |
| - func: sub_(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor |
| variants: method |
| |
| |
| |
| - func: s_native_mul_out(Tensor result, Tensor self, Tensor other) -> Tensor |
| variants: function |
| dispatch: |
| SparseCPU: s_mul_out_sparse_cpu |
| SparseCUDA: s_mul_out_sparse_cuda |
| |
| - func: s_native_mul(Tensor self, Tensor other) -> Tensor |
| variants: function |
| dispatch: |
| SparseCPU: s_mul_sparse_cpu |
| SparseCUDA: s_mul_sparse_cuda |
| |
| - func: s_native_mul_(Tensor self, Tensor other) -> Tensor |
| variants: function |
| dispatch: |
| SparseCPU: s_mul_sparse_cpu_ |
| SparseCUDA: s_mul_sparse_cuda_ |
| |
| - func: native_mul_out(Tensor result, Tensor self, Scalar other) -> Tensor |
| variants: function |
| dispatch: |
| SparseCPU: mul_out_sparse_scalar |
| SparseCUDA: mul_out_sparse_scalar |
| |
| - func: native_mul(Tensor self, Scalar other) -> Tensor |
| variants: function |
| dispatch: |
| SparseCPU: mul_sparse_scalar |
| SparseCUDA: mul_sparse_scalar |
| |
| - func: native_mul_(Tensor self, Scalar other) -> Tensor |
| variants: function |
| dispatch: |
| SparseCPU: mul_sparse_scalar_ |
| SparseCUDA: mul_sparse_scalar_ |
| |
| - func: mul_out(Tensor result, Tensor self, Tensor other) -> Tensor |
| variants: function |
| |
| - func: mul_out(Tensor result, Tensor self, Scalar other) -> Tensor |
| variants: function |
| |
| - func: mul(Tensor self, Tensor other) -> Tensor |
| variants: method, function |
| |
| - func: mul(Tensor self, Scalar other) -> Tensor |
| variants: method, function |
| |
| - func: mul_(Tensor self, Tensor other) -> Tensor |
| variants: method |
| |
| - func: mul_(Tensor self, Scalar other) -> Tensor |
| variants: method |
| |
| |
| |
| - func: native_div_out(Tensor result, Tensor self, Scalar other) -> Tensor |
| variants: function |
| dispatch: |
| SparseCPU: div_out_sparse_scalar |
| SparseCUDA: div_out_sparse_scalar |
| |
| - func: native_div(Tensor self, Scalar other) -> Tensor |
| variants: function |
| dispatch: |
| SparseCPU: div_sparse_scalar |
| SparseCUDA: div_sparse_scalar |
| |
| - func: native_div_(Tensor self, Scalar other) -> Tensor |
| variants: function |
| dispatch: |
| SparseCPU: div_sparse_scalar_ |
| SparseCUDA: div_sparse_scalar_ |
| |
| - func: div_out(Tensor result, Tensor self, Scalar other) -> Tensor |
| variants: function |
| |
| - func: div(Tensor self, Scalar other) -> Tensor |
| variants: method, function |
| |
| - func: div_(Tensor self, Scalar other) -> Tensor |
| variants: method |
| |
| |
| - func: s_native_addmm_out(Tensor result, Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor |
| variants: function |
| dispatch: |
| CPU: s_addmm_out_sparse_dense_cpu |
| CUDA: s_addmm_out_sparse_dense_cuda |
| |
| - func: s_native_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor |
| variants: function |
| dispatch: |
| CPU: s_addmm_sparse_dense_cpu |
| CUDA: s_addmm_sparse_dense_cuda |
| |
| - func: s_native_addmm_(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor |
| variants: function |
| dispatch: |
| CPU: s_addmm_sparse_dense_cpu_ |
| CUDA: s_addmm_sparse_dense_cuda_ |
| |
| - func: addmm_out(Tensor result, Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor |
| variants: function |
| |
| - func: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor |
| variants: method, function |
| |
| - func: addmm_(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor |
| variants: method |
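
# Hedged sketch (assuming the usual torch.addmm semantics): addmm computes
# beta * self + alpha * (mat1 @ mat2).
#
#   bias = torch.randn(2, 4)
#   a = torch.randn(2, 3)
#   b = torch.randn(3, 4)
#   torch.addmm(bias, a, b)                        # == bias + a.mm(b)
#   torch.addmm(bias, a, b, beta=0.5, alpha=2.0)   # == 0.5 * bias + 2 * a.mm(b)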
| |
| |
| - func: native_tensor(Type self_ty) -> Tensor |
| variants: function |
| dispatch: |
| SparseCPU: new_sparse |
| SparseCUDA: new_sparse |
| |
| - func: native_tensor(Type self_ty, IntList size) -> Tensor |
| variants: function |
| dispatch: |
| SparseCPU: new_with_size_sparse |
| SparseCUDA: new_with_size_sparse |
| |
| - func: tensor(Type dtype) -> Tensor |
| variants: [] |
| |
| - func: tensor(Type dtype, IntList size) -> Tensor |
| variants: [] |
| |
| |
# NB: sparse_coo_tensor has to be decomposed into two functions because it
# needs custom dispatch logic to decide which Type to dispatch on (we must
# use the sparse equivalent of the type of the SECOND argument, values).
#
# The actual dispatcher, native_sparse_coo_tensor, has all of its variants
# removed (variants: []) so you don't accidentally trigger the default
# behavior, which is to infer the Type from the first argument (indices);
# that is almost never what you want. (Hypothetically it would still work;
# you would just always dispatch to CPULongTensor or CUDALongTensor, but
# relying on that seems too fragile.) A construction sketch follows the
# declarations below.
| |
| - func: native_sparse_coo_tensor(IndexTensor indices, Tensor values) -> Tensor |
| variants: [] |
| dispatch: |
| SparseCPU: new_with_tensor_sparse |
| SparseCUDA: new_with_tensor_sparse |
| |
| - func: native_sparse_coo_tensor(IndexTensor indices, Tensor values, IntList size) -> Tensor |
| variants: [] |
| dispatch: |
| SparseCPU: new_with_tensor_and_size_sparse |
| SparseCUDA: new_with_tensor_and_size_sparse |
| |
| - func: sparse_coo_tensor(IndexTensor indices, Tensor values) -> Tensor |
| variants: [] |
| |
| - func: sparse_coo_tensor(IndexTensor indices, Tensor values, IntList size) -> Tensor |
| variants: [] |
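
# Hedged construction sketch (assuming the Python-level torch.sparse_coo_tensor
# wrapper): 'indices' is a 2 x nnz LongTensor of coordinates, 'values' carries
# the stored entries, and dispatch follows the type of 'values' (the second
# argument), per the note above.
#
#   i = torch.LongTensor([[0, 1, 1],
#                         [2, 0, 2]])          # coordinates, one column per entry
#   v = torch.FloatTensor([3.0, 4.0, 5.0])     # values
#   s = torch.sparse_coo_tensor(i, v, [2, 3])  # 2x3 sparse COO tensor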
| |
| |
| - func: _native_sparse_coo_tensor_unsafe(IndexTensor indices, Tensor values, IntList size) -> Tensor |
| variants: [] |
| dispatch: |
| SparseCPU: new_with_tensor_and_size_unsafe_sparse |
| SparseCUDA: new_with_tensor_and_size_unsafe_sparse |
| |
| - func: _sparse_coo_tensor_unsafe(IndexTensor indices, Tensor values, IntList size) -> Tensor |
| variants: function |
| |
| |
| - func: sparse_raw_resize_(Tensor self, IntList size, int64_t sparseDims, int64_t denseDims) -> Tensor |
| variants: method |
| dispatch: |
| SparseCPU: raw_resize_sparse_ |
| SparseCUDA: raw_resize_sparse_ |
| |
| |
| - func: _sparse_mask(Tensor self, SparseTensorRef mask) -> Tensor |
| variants: method |
| dispatch: |
| CPU: sparse_mask_cpu |
| CUDA: sparse_mask_cuda |
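
# Hedged sketch of _sparse_mask (an assumption about its semantics, and that the
# generated method keeps the declared name): given a dense 'self' and a sparse
# 'mask', it returns a sparse tensor with the mask's sparsity pattern whose
# values are read out of 'self' at those indices.
#
#   d = torch.randn(2, 3)          # dense source
#   masked = d._sparse_mask(s)     # s: a sparse 2x3 tensor, e.g. from the sketch above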
| |
| |
| - func: to_dense(Tensor self) -> Tensor |
| variants: method |
| dispatch: |
| SparseCPU: sparse_to_dense |
| SparseCUDA: sparse_to_dense |
| |
| |
| - func: _sparseDims(Tensor self) -> int64_t |
| variants: method |
| dispatch: |
| SparseCPU: _sparseDims_sparse |
| SparseCUDA: _sparseDims_sparse |
| device_guard: False |
| |
| # legacy method |
| - func: _dimI(Tensor self) -> int64_t |
| variants: method |
| dispatch: _sparseDims_sparse |
| device_guard: False |
| |
| |
| - func: _denseDims(Tensor self) -> int64_t |
| variants: method |
| dispatch: |
| SparseCPU: _denseDims_sparse |
| SparseCUDA: _denseDims_sparse |
| device_guard: False |
| |
| # legacy method |
| - func: _dimV(Tensor self) -> int64_t |
| variants: method |
| dispatch: _denseDims_sparse |
| device_guard: False |
| |
| |
| - func: _nnz(Tensor self) -> int64_t |
| variants: method |
| dispatch: |
| SparseCPU: _nnz_sparse |
| SparseCUDA: _nnz_sparse |
| device_guard: False |
| |
| |
| - func: coalesce(Tensor self) -> Tensor |
| variants: method |
| dispatch: |
| SparseCPU: coalesce_sparse_cpu |
| SparseCUDA: coalesce_sparse_cuda |
| |
| |
| - func: is_coalesced(Tensor self) -> bool |
| variants: method |
| dispatch: |
| SparseCPU: is_coalesced_sparse |
| SparseCUDA: is_coalesced_sparse |
| device_guard: False |
| |
| |
| - func: _indices(Tensor self) -> Tensor |
| variants: method |
| dispatch: |
| SparseCPU: _indices_sparse |
| SparseCUDA: _indices_sparse |
| device_guard: False |
| |
| |
| - func: _values(Tensor self) -> Tensor |
| variants: method |
| dispatch: |
| SparseCPU: _values_sparse |
| SparseCUDA: _values_sparse |
| device_guard: False |
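
# Hedged sketch tying the sparse accessors together (assumes a coalesced 2-D
# COO tensor with no dense dimensions, e.g. the one built in the
# sparse_coo_tensor sketch above):
#
#   s = s.coalesce()       # merge duplicate coordinates; is_coalesced() -> True
#   s._sparseDims()        # 2   (number of sparse dimensions)
#   s._denseDims()         # 0   (no dense dimensions in the values)
#   s._nnz()               # 3   (number of stored entries)
#   s._indices()           # LongTensor of shape (2, 3)
#   s._values()            # values tensor of shape (3,)
#   s.to_dense()           # materializes the full 2x3 dense tensor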
| |
| |
| - func: hspmm_out(Tensor result, Tensor mat1, Tensor mat2) -> Tensor |
| variants: function |
| dispatch: |
| SparseCPU: hspmm_out_sparse_cpu |
| SparseCUDA: hspmm_out_sparse_cuda |
| |
| - func: hspmm(Tensor mat1, Tensor mat2) -> Tensor |
| variants: function |
| dispatch: |
| SparseCPU: hspmm_sparse_cpu |
| SparseCUDA: hspmm_sparse_cuda |
| |
| # This "raw copy" doesn't handle conversions NOR does it handle non-blocking. |
| - func: raw_copy_sparse_(Tensor self, Tensor src) -> Tensor |
| variants: function |
| dispatch: |
| SparseCPU: copy_sparse_ |
| SparseCUDA: copy_sparse_ |
| |
| - func: numel(Tensor self) -> int64_t |
variants: method, function
| device_guard: False |
| |
| - func: unbind(Tensor self, int64_t dim=0) -> TensorList |
variants: method, function
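
# Hedged sketch (assuming torch.unbind's usual behavior): unbind slices a
# tensor along 'dim' and returns the slices with that dimension removed.
#
#   t = torch.arange(6).view(2, 3)
#   a, b = torch.unbind(t, dim=0)   # a == t[0] (shape (3,)), b == t[1]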
| |
| - func: native_get_device(Tensor self) -> int64_t |
| variants: function |
| dispatch: |
| SparseCUDA: get_device_sparse_cuda |
| device_guard: False |
| |
| - func: get_device(Tensor self) -> int64_t |
| device_guard: False |
| |
| - func: meshgrid(TensorList tensors) -> TensorList |
| variants: function |
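
# Hedged sketch (assuming the torch.meshgrid wrapper built on this TensorList
# declaration): given 1-D inputs of lengths m and n, it returns two (m, n)
# grids, the first varying along rows and the second along columns.
#
#   x = torch.tensor([1., 2., 3.])
#   y = torch.tensor([4., 5.])
#   gx, gy = torch.meshgrid([x, y])   # gx[i][j] == x[i], gy[i][j] == y[j]; shape (3, 2)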