disable __torch_function__ overrides for operators in torch.functional (#30839)
Summary:
For now I'm just removing the decorators from all of the currently overridable functions in `torch.functional`. This means they are no longer overridable; however, it should fix the benchmark regressions reported in https://github.com/pytorch/pytorch/issues/30831. Moving forward we'll look at reducing the overhead of the Python-level override mechanism and, failing that, re-implementing all of these operators in C++.
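For context, the decorators deleted in the diff below wrap each function in a Python-level dispatch step. The following is a minimal sketch of that per-call work, not the actual `torch._overrides` implementation: collect the dispatch-relevant arguments, probe each one for a `__torch_function__` hook, and only then fall through to the real function. The reuse of `_isinf_dispatcher` mirrors the pattern removed from `torch/functional.py`; the wrapper body itself is an illustrative assumption.

```python
import torch

def torch_function_dispatch(dispatcher):
    # Sketch only -- *not* the real torch._overrides implementation. On every
    # call it asks the dispatcher for the arguments that may participate in
    # overriding, then probes each one for a __torch_function__ hook before
    # falling through to the wrapped function.
    def decorator(func):
        def wrapper(*args, **kwargs):
            relevant_args = dispatcher(*args, **kwargs)
            for arg in relevant_args:
                if not isinstance(arg, torch.Tensor) and hasattr(arg, '__torch_function__'):
                    # A tensor-like argument opted in; hand it the call.
                    return arg.__torch_function__(func, args, kwargs)
            return func(*args, **kwargs)
        return wrapper
    return decorator


def _isinf_dispatcher(tensor):
    return (tensor,)


@torch_function_dispatch(_isinf_dispatcher)
def isinf(tensor):
    return tensor.abs() == float('inf')
```

Even when nothing overrides anything (the common case in the benchmarks from #30831), the loop and attribute checks above run on every call; removing the decorators lets these wrappers call straight into `torch._C._VariableFunctions` again.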
cc hl475
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30839
Differential Revision: D18838848
Pulled By: ezyang
fbshipit-source-id: 22b8015d7b2f7a947f1ebc9632c998e081b48ad8
diff --git a/test/test_overrides.py b/test/test_overrides.py
index 8289efe..e45f7da 100644
--- a/test/test_overrides.py
+++ b/test/test_overrides.py
@@ -426,13 +426,9 @@
(torch.bitwise_not, lambda input, out=None: -1),
(torch.bitwise_xor, lambda input, other, out=None: -1),
(torch.bmm, lambda input, mat2, out=None: -1),
- (torch.broadcast_tensors, lambda *tensors: -1),
- (torch.cartesian_prod, lambda *tensors: -1),
(torch.cat, lambda tensors, dim=0, out=None: -1),
- (torch.cdist, lambda x1, c2, p=2, compute_mode=None: -1),
(torch.ceil, lambda input, out=None: -1),
(torch.celu, lambda input, alhpa=1., inplace=False: -1),
- (torch.chain_matmul, lambda *matrices: -1),
(torch.cholesky, lambda input, upper=False, out=None: -1),
(torch.cholesky_inverse, lambda input, upper=False, out=None: -1),
(torch.cholesky_solve, lambda input1, input2, upper=False, out=None: -1),
@@ -475,8 +471,6 @@
(torch.dsmm, lambda input, mat2: -1),
(torch.hsmm, lambda mat1, mat2: -1),
(torch.eig, lambda input, eigenvectors=False, out=None: -1),
- (torch.einsum, lambda equation, *operands: -1),
- (torch.einsum, lambda equation, *operands: -1),
(torch.embedding, lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False,
sparse=False: -1),
(torch.embedding_bag, lambda input, weight, offsets, max_norm=None, norm_type=2, scale_grad_by_freq=False,
@@ -532,7 +526,6 @@
(torch.index_select, lambda input, dim, index, out=None: -1),
(torch.index_fill, lambda input, dim, index, value: -1),
(torch.isfinite, lambda tensor: -1),
- (torch.isinf, lambda tensor: -1),
(torch.instance_norm, lambda input, running_mean, running_var, weight, bias, use_input_stats, momentum, eps, cudnn_enabled: -1),
(torch.int_repr, lambda input: -1),
(torch.inverse, lambda input, out=None: -1),
@@ -564,7 +557,6 @@
(torch.lstm_cell, lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1),
(torch.lstsq, lambda input, A, out=None: -1),
(torch.lt, lambda input, other, out=None: -1),
- (torch.lu, lambda A, pivot=True, get_infos=False, out=None: -1),
(torch.lu_solve, lambda input, LU_data, LU_pivots, out=None: -1),
(torch.margin_ranking_loss, lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean': -1),
(torch.masked_fill, lambda input, mask, value: -1),
@@ -581,7 +573,6 @@
ceil_mode=False: -1),
(torch.mean, lambda input: -1),
(torch.median, lambda input: -1),
- (torch.meshgrid, lambda *tensors, **kwargs: -1),
(torch.min, lambda input, out=None: -1),
(torch.miopen_batch_norm, lambda input, weight, bias, running_mean, running_var, training, exponential_average_factor,
epsilon: -1),
@@ -605,7 +596,6 @@
(torch.ne, lambda input, other, out=None: -1),
(torch.neg, lambda input, out=None: -1),
(torch.nonzero, lambda input, out=None, as_tuple=False: -1),
- (torch.norm, lambda input, p='fro', dim=None, keepdim=False, out=None, dtype=None: -1),
(torch.norm_except_dim, lambda v, pow=2, dim=0: -1),
(torch.normal, lambda mean, std, out=None: -1),
(torch.nuclear_norm, lambda input, p='fro', dim=None, keepdim=False, out=None, dtype=None: -1),
@@ -681,7 +671,6 @@
(torch.softmax, lambda input, dim, dtype=None: -1),
(torch.solve, lambda input, A, out=None: -1),
(torch.sort, lambda input, dim=-1, descending=False, out=None: -1),
- (torch.split, lambda tensor, split_size_or_sections, dim=0: -1),
(torch.split_with_sizes, lambda tensor, split_size_or_sections, dim=0: -1),
(torch.sqrt, lambda input, out=None: -1),
(torch.squeeze, lambda input, dim=None, out=None: -1),
@@ -689,8 +678,6 @@
(torch.stack, lambda tensors, dim=0, out=None: -1),
(torch.std, lambda input: -1),
(torch.std_mean, lambda input: -1),
- (torch.stft, lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True, pad_mode='reflect',
- normalized=False, onesided=True: -1),
(torch.sub, lambda input, other, out=None: -1),
(torch.sum, lambda input: -1),
(torch.svd, lambda input, some=True, compute_uv=True, out=None: -1),
@@ -699,7 +686,6 @@
(torch.take, lambda input, index: -1),
(torch.tan, lambda input, out=None: -1),
(torch.tanh, lambda input, out=None: -1),
- (torch.tensordot, lambda a, b, dims=2: -1),
(torch.threshold, lambda input, threshold, value, inplace=False: -1),
(torch.topk, lambda input, k, dim=-1, descending=False, out=None: -1),
(torch.trace, lambda input: -1),
@@ -714,8 +700,6 @@
(torch.triu_indices, lambda row, col, offset=0, dtype=torch.long, device='cpu', layout=torch.strided: -1),
(torch.trunc, lambda input, out=None: -1),
(torch.unbind, lambda input, dim=0: -1),
- (torch.unique, lambda input, sorted=True, return_inverse=False, return_counts=False, dim=None: -1),
- (torch.unique_consecutive, lambda input, return_inverse=False, return_counts=False, dim=None: -1),
(torch.unsqueeze, lambda input, dim, out=None: -1),
(torch.var, lambda input: -1),
(torch.var_mean, lambda input: -1),
diff --git a/torch/functional.py b/torch/functional.py
index 585c36b..197e373 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -3,8 +3,6 @@
from torch._six import inf
from itertools import product
-from ._overrides import torch_function_dispatch
-
__all__ = [
'align_tensors',
'broadcast_tensors',
@@ -24,10 +22,6 @@
'unique_consecutive',
]
-def _broadcast_tensors_dispatcher(*tensors):
- return tensors
-
-@torch_function_dispatch(_broadcast_tensors_dispatcher)
def broadcast_tensors(*tensors):
r"""broadcast_tensors(*tensors) -> List of Tensors
@@ -57,11 +51,6 @@
return torch._C._VariableFunctions.broadcast_tensors(tensors)
-def _split_dispatcher(tensor, split_size_or_sections, dim=0):
- return (tensor,)
-
-
-@torch_function_dispatch(_split_dispatcher)
def split(tensor, split_size_or_sections, dim=0):
r"""Splits the tensor into chunks.
@@ -180,11 +169,6 @@
return P, L, U
-def _einsum_dispatcher(equation, *operands):
- return operands
-
-
-@torch_function_dispatch(_einsum_dispatcher)
def einsum(equation, *operands):
r"""einsum(equation, *operands) -> Tensor
@@ -258,11 +242,6 @@
return torch._C._VariableFunctions.einsum(equation, operands)
-def _isinf_dispatcher(tensor):
- return (tensor,)
-
-
-@torch_function_dispatch(_isinf_dispatcher)
def isinf(tensor):
r"""Returns a new tensor with boolean elements representing if each element is `+/-INF` or not.
@@ -284,11 +263,6 @@
return tensor.abs() == inf
-def _meshgrid_dispatcher(*tensors, **kwargs):
- return tensors
-
-
-@torch_function_dispatch(_meshgrid_dispatcher)
def meshgrid(*tensors, **kwargs):
r"""Take :math:`N` tensors, each of which can be either scalar or 1-dimensional
vector, and create :math:`N` N-dimensional grids, where the :math:`i` :sup:`th` grid is defined by
@@ -326,13 +300,6 @@
return torch._C._VariableFunctions.meshgrid(tensors)
-def _stft_dispatcher(input, n_fft, hop_length=None, win_length=None,
- window=None, center=True, pad_mode='reflect',
- normalized=False, onesided=True):
- return (input,)
-
-
-@torch_function_dispatch(_stft_dispatcher)
def stft(input, n_fft, hop_length=None, win_length=None, window=None,
center=True, pad_mode='reflect', normalized=False, onesided=True):
# type: (Tensor, int, Optional[int], Optional[int], Optional[Tensor], bool, str, bool, bool) -> Tensor
@@ -429,12 +396,6 @@
del torch.unique_dim
-def _unique_dispatcher(input, sorted=None, return_inverse=None,
- return_counts=None, dim=None):
- return (input,)
-
-
-@torch_function_dispatch(_unique_dispatcher)
def unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
r"""Returns the unique elements of the input tensor.
@@ -518,11 +479,6 @@
else:
return output
-def _unique_consecutive_dispatcher(
- input, return_inverse=None, return_counts=None, dim=None):
- return (input,)
-
-@torch_function_dispatch(_unique_consecutive_dispatcher)
def unique_consecutive(input, return_inverse=False, return_counts=False, dim=None):
r"""Eliminates all but the first element from every consecutive group of equivalent elements.
@@ -583,11 +539,6 @@
return output, counts
return output
-def _tensordot_dispatcher(a, b, dims=None):
- return (a, b)
-
-
-@torch_function_dispatch(_tensordot_dispatcher)
def tensordot(a, b, dims=2):
r"""Returns a contraction of a and b over multiple dimensions.
@@ -642,7 +593,6 @@
dims_b = list(range(dims))
return torch._C._VariableFunctions.tensordot(a, b, dims_a, dims_b)
-@torch_function_dispatch(_broadcast_tensors_dispatcher)
def cartesian_prod(*tensors):
"""Do cartesian product of the given sequence of tensors. The behavior is similar to
python's `itertools.product`.
@@ -673,10 +623,6 @@
"""
return torch._C._VariableFunctions.cartesian_prod(tensors)
-def _cdist_dispatcher(x1, x2, p=2, compute_mode='use_mm_for_euclid_dist_if_necessary'):
- return (x1, x2)
-
-@torch_function_dispatch(_cdist_dispatcher)
def cdist(x1, x2, p=2, compute_mode='use_mm_for_euclid_dist_if_necessary'):
r"""Computes batched the p-norm distance between each pair of the two collections of row vectors.
@@ -728,11 +674,6 @@
raise ValueError("{} is not a valid value for compute_mode".format(compute_mode))
-def _norm_dispatcher(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):
- return (input,)
-
-
-@torch_function_dispatch(_norm_dispatcher)
def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):
r"""Returns the matrix norm or vector norm of a given tensor.
@@ -830,11 +771,6 @@
return torch._C._VariableFunctions.norm(input, p, dim, keepdim=keepdim, dtype=dtype, out=out)
-def _chain_matmul_dispatcher(*matrices):
- return matrices
-
-
-@torch_function_dispatch(_chain_matmul_dispatcher)
def chain_matmul(*matrices):
r"""Returns the matrix product of the :math:`N` 2-D tensors. This product is efficiently computed
using the matrix chain order algorithm which selects the order in which incurs the lowest cost in terms
@@ -866,11 +802,6 @@
"""
return torch._C._VariableFunctions.chain_matmul(matrices)
-def _lu_dispatcher(A, pivot=None, get_infos=None, out=None):
- return (A,)
-
-
-@torch_function_dispatch(_lu_dispatcher)
def lu(A, pivot=True, get_infos=False, out=None):
r"""Computes the LU factorization of a matrix or batches of matrices
:attr:`A`. Returns a tuple containing the LU factorization and