move has_torch_function to C++, and make a special case object_has_torch_function (#48965)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48965
This PR pulls `__torch_function__` checking entirely into C++, and adds a special `object_has_torch_function` method for ops which only have one arg as this lets us skip tuple construction and unpacking. We can now also do away with the Python side fast bailout for `Tensor` (e.g. `if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors)`) because they're actually slower than checking with the Python C API.
Test Plan: Existing unit tests. Benchmarks are in #48966
Reviewed By: ezyang
Differential Revision: D25590732
Pulled By: robieta
fbshipit-source-id: 6bd74788f06cdd673f3a2db898143d18c577eb42
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 6427a4a..dd877da 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -4,9 +4,10 @@
from torch import Tensor
from enum import Enum
from pathlib import Path
-from typing import (Any, BinaryIO, Callable, ContextManager, Dict, Iterator, List, NamedTuple,
- Optional, overload, Sequence, Tuple, TypeVar, Type, Union, Generic,
- Set, AnyStr)
+from typing import (
+ Any, BinaryIO, Callable, ContextManager, Dict, Iterable, Iterator, List,
+ NamedTuple, Optional, overload, Sequence, Tuple, TypeVar, Type, Union,
+ Generic, Set, AnyStr)
from torch._six import inf
from torch.types import _int, _float, _bool, _dtype, _device, _qscheme, _size, _layout, Device, Number, Storage
@@ -498,6 +499,9 @@
def _set_qengine(qegine: _int) -> None: ... # THPModule_setQEngine
def _supported_qengines() -> List[_int]: ... # THPModule_supportedQEngines
def _is_xnnpack_enabled() -> _bool: ... # THPModule_isEnabledXNNPACK
+def _has_torch_function(args: Iterable[Any]) -> _bool: ... # THPModule_has_torch_function
+def _has_torch_function_unary(Any) -> _bool: ... # THPModule_has_torch_function_unary
+def _has_torch_function_variadic(*args: Any) -> _bool: ... # THPModule_has_torch_function_variadic
def _vmapmode_increment_nesting() -> _int: ... # THPModule_vmapmode_increment_nesting
def _vmapmode_decrement_nesting() -> _int: ... # THPModule_vmapmode_decrement_nesting
def _log_api_usage_once(str) -> None: ... # LogAPIUsageOnceFromPython
@@ -636,6 +640,8 @@
_version: _int
_base: Optional[Tensor]
grad_fn: Any
+ _grad: Optional[Tensor]
+ _backward_hooks: Optional[Dict[_int, Callable[[Tensor], Optional[Tensor]]]]
${tensor_method_hints}
# Defined in torch/csrc/multiprocessing/init.cpp
diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp
index f70bd1a..ca99965 100644
--- a/torch/csrc/Module.cpp
+++ b/torch/csrc/Module.cpp
@@ -40,6 +40,7 @@
#include <torch/csrc/tensor/python_tensor.h>
#include <torch/csrc/utils/disable_torch_function.h>
#include <torch/csrc/utils/tensor_dtypes.h>
+#include <torch/csrc/utils/python_compat.h>
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/utils/tensor_layouts.h>
#include <torch/csrc/utils/tensor_memoryformats.h>
@@ -629,6 +630,9 @@
{"_is_xnnpack_enabled", THPModule_isEnabledXNNPACK, METH_NOARGS, nullptr},
{"_is_torch_function_enabled", THPModule_isEnabledTorchFunction, METH_NOARGS, nullptr},
{"_disabled_torch_function_impl", THPModule_disable_torch_function, METH_VARARGS, nullptr},
+ {"_has_torch_function", THPModule_has_torch_function, METH_O, nullptr},
+ {"_has_torch_function_unary", THPModule_has_torch_function_unary, METH_O, nullptr},
+ {"_has_torch_function_variadic", MAYBE_WRAP_FASTCALL(THPModule_has_torch_function_variadic), MAYBE_METH_FASTCALL, nullptr},
{nullptr, nullptr, 0, nullptr}
};
diff --git a/torch/csrc/autograd/python_variable.h b/torch/csrc/autograd/python_variable.h
index cf05aec..41a2cca 100644
--- a/torch/csrc/autograd/python_variable.h
+++ b/torch/csrc/autograd/python_variable.h
@@ -24,18 +24,21 @@
bool THPVariable_initModule(PyObject *module);
THP_API PyObject * THPVariable_Wrap(torch::autograd::Variable var);
-static inline bool THPVariable_CheckExact(PyObject *obj) {
+static inline bool THPVariable_CheckTypeExact(PyTypeObject* tp) {
// Check that a python object is a `Tensor`, but not a `Tensor` subclass.
// (A subclass could have different semantics.) The one exception is
// Parameter, which is used for Python bookkeeping but is equivalent to
// Tensor as far as C++ is concerned.
- auto obj_py_type = Py_TYPE(obj);
return (
- obj_py_type == (PyTypeObject*)THPVariableClass ||
- obj_py_type == (PyTypeObject*)ParameterClass
+ tp == (PyTypeObject*)THPVariableClass ||
+ tp == (PyTypeObject*)ParameterClass
);
}
+static inline bool THPVariable_CheckExact(PyObject *obj) {
+ return THPVariable_CheckTypeExact(Py_TYPE(obj));
+}
+
inline bool THPVariable_Check(PyObject *obj)
{
return THPVariableClass && PyObject_IsInstance(obj, THPVariableClass);
diff --git a/torch/csrc/utils/disable_torch_function.cpp b/torch/csrc/utils/disable_torch_function.cpp
index d3d1a3f..28414dc 100644
--- a/torch/csrc/utils/disable_torch_function.cpp
+++ b/torch/csrc/utils/disable_torch_function.cpp
@@ -1,6 +1,7 @@
#include <torch/csrc/utils/disable_torch_function.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/Exceptions.h>
+#include <torch/csrc/utils/python_strings.h>
namespace torch {
static thread_local bool enable_torch_function = true;
@@ -125,3 +126,111 @@
return result;
END_HANDLE_TH_ERRORS
}
+
+// Makes sure that we don't check for __torch_function__ on basic Python types
+static bool is_basic_python_type(PyTypeObject *tp)
+{
+ return (
+ /* Basic number types */
+ tp == &PyBool_Type ||
+
+ tp == &PyLong_Type ||
+ tp == &PyFloat_Type ||
+ tp == &PyComplex_Type ||
+
+ /* Basic sequence types */
+ tp == &PyList_Type ||
+ tp == &PyTuple_Type ||
+ tp == &PyDict_Type ||
+ tp == &PySet_Type ||
+ tp == &PyFrozenSet_Type ||
+ tp == &PyUnicode_Type ||
+ tp == &PyBytes_Type ||
+
+ /* other builtins */
+ tp == &PySlice_Type ||
+ tp == Py_TYPE(Py_None) ||
+ tp == Py_TYPE(Py_Ellipsis) ||
+ tp == Py_TYPE(Py_NotImplemented) ||
+
+ PyModule_Check(tp) ||
+ /* sentinel to swallow trailing || */
+ false
+ );
+}
+
+inline bool has_torch_function_attr(PyObject* obj) {
+ auto attr = PyObject_FastGetAttrString(obj, "__torch_function__");
+ return (
+ attr.ptr() != nullptr &&
+ attr.ptr() != torch::disabled_torch_function);
+}
+
+namespace torch {
+auto check_has_torch_function(PyObject* obj) -> bool
+{
+ PyTypeObject *tp = Py_TYPE(obj);
+ return (
+ !THPVariable_CheckTypeExact(tp) &&
+ !is_basic_python_type(tp) &&
+ torch::torch_function_enabled() &&
+ has_torch_function_attr(obj)
+ );
+}
+} // namespace torch
+
+inline bool sequence_has_torch_function(PyObject* args) {
+ Py_ssize_t nargs = PySequence_Fast_GET_SIZE(args);
+ for (Py_ssize_t i = 0; i < nargs; i++) {
+ PyObject* obj = PySequence_Fast_GET_ITEM(args, i);
+ if (torch::check_has_torch_function(obj)) {
+ return true;
+ }
+ }
+ return false;
+}
+
+inline bool array_has_torch_function(PyObject *const *args, Py_ssize_t nargs) {
+ for (Py_ssize_t i = 0; i < nargs; i++) {
+ if (torch::check_has_torch_function(args[i])) {
+ return true;
+ }
+ }
+ return false;
+}
+
+PyObject* THPModule_has_torch_function(PyObject*, PyObject *arg) {
+ bool result; // NOLINT(cppcoreguidelines-init-variables)
+ if (PyTuple_CheckExact(arg) || PyList_CheckExact(arg)) {
+ // Fast path:
+ // If we know that we have a tuple or list, we can skip an INCREF and
+ // DECREF from PySequence_Fast. Core functions will always follow this
+ // convention (almost always tuples), and it shaves ~3.5% off the cost of
+ // the check.
+ result = sequence_has_torch_function(arg);
+ } else {
+ auto args = py::reinterpret_steal<py::object>(
+ PySequence_Fast(arg, "expected a sequence"));
+ result = sequence_has_torch_function(args.ptr());
+ }
+
+ if (result) {
+ Py_RETURN_TRUE;
+ }
+ Py_RETURN_FALSE;
+}
+
+PyObject* THPModule_has_torch_function_unary(PyObject*, PyObject *obj) {
+ // Special case `THPModule_has_torch_function` for the single arg case.
+ if (torch::check_has_torch_function(obj)) {
+ Py_RETURN_TRUE;
+ }
+ Py_RETURN_FALSE;
+}
+
+PyObject* THPModule_has_torch_function_variadic(PyObject*, PyObject *const *args, Py_ssize_t nargs) {
+ if (array_has_torch_function(args, nargs)) {
+ Py_RETURN_TRUE;
+ }
+ Py_RETURN_FALSE;
+}
diff --git a/torch/csrc/utils/disable_torch_function.h b/torch/csrc/utils/disable_torch_function.h
index 1216660..c4f9651 100644
--- a/torch/csrc/utils/disable_torch_function.h
+++ b/torch/csrc/utils/disable_torch_function.h
@@ -9,8 +9,12 @@
bool torch_function_enabled();
PyObject* disabled_torch_function_impl();
void set_disabled_torch_function_impl(PyObject* value);
+ bool check_has_torch_function(PyObject* obj);
}
PyObject* THPModule_isEnabledTorchFunction(PyObject* self, PyObject *unused);
PyObject* THPModule_DisableTorchFunctionType();
-PyObject* THPModule_disable_torch_function(PyObject *self, PyObject *args);
\ No newline at end of file
+PyObject* THPModule_disable_torch_function(PyObject *self, PyObject *args);
+PyObject* THPModule_has_torch_function(PyObject*, PyObject *arg);
+PyObject* THPModule_has_torch_function_unary(PyObject*, PyObject *obj);
+PyObject* THPModule_has_torch_function_variadic(PyObject*, PyObject *const *args, Py_ssize_t nargs);
diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp
index aa73b2a..4208f65 100644
--- a/torch/csrc/utils/python_arg_parser.cpp
+++ b/torch/csrc/utils/python_arg_parser.cpp
@@ -862,7 +862,7 @@
}
int i = 0;
- if (self != nullptr && !THPVariable_CheckExact(self) && check_has_torch_function(self)) {
+ if (self != nullptr && check_has_torch_function(self)) {
append_overloaded_arg(&this->overloaded_args, self);
}
for (auto& param : params) {
diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h
index 9fa4901..0f7f595 100644
--- a/torch/csrc/utils/python_arg_parser.h
+++ b/torch/csrc/utils/python_arg_parser.h
@@ -674,127 +674,6 @@
}
/*
- * Reference: https://github.com/numpy/numpy/blob/f4c497c768e0646df740b647782df463825bfd27/numpy/core/src/common/get_attr_string.h#L42
- *
- * Stripped down version of PyObject_GetAttrString,
- * avoids lookups for None, tuple, and List objects,
- * and doesn't create a PyErr since this code ignores it.
- *
- * This can be much faster then PyObject_GetAttrString where
- * exceptions are not used by caller.
- *
- * 'obj' is the object to search for attribute.
- *
- * 'name' is the attribute to search for.
- *
- * Returns a py::object wrapping the return value. If the attribute lookup failed
- * the value will be NULL.
- *
- */
-
-static py::object PyObject_FastGetAttrString(PyObject *obj, char *name)
-{
- PyTypeObject *tp = Py_TYPE(obj);
- PyObject *res = (PyObject *)NULL;
-
- /* Attribute referenced by (char *)name */
- if (tp->tp_getattr != NULL) {
- res = (*tp->tp_getattr)(obj, name);
- if (res == NULL) {
- PyErr_Clear();
- }
- }
- /* Attribute referenced by (PyObject *)name */
- else if (tp->tp_getattro != NULL) {
- PyObject *w = THPUtils_internString(name);
- if (w == NULL) {
- return py::object();
- }
- res = (*tp->tp_getattro)(obj, w);
- Py_DECREF(w);
- if (res == NULL) {
- PyErr_Clear();
- }
- }
- return py::reinterpret_steal<py::object>(res);
-}
-
-// Makes sure that we don't check for __torch_function__ on basic Python types
-static bool _is_basic_python_type(PyTypeObject *tp)
-{
- return (
- /* Basic number types */
- tp == &PyBool_Type ||
-
- tp == &PyLong_Type ||
- tp == &PyFloat_Type ||
- tp == &PyComplex_Type ||
-
- /* Basic sequence types */
- tp == &PyList_Type ||
- tp == &PyTuple_Type ||
- tp == &PyDict_Type ||
- tp == &PySet_Type ||
- tp == &PyFrozenSet_Type ||
- tp == &PyUnicode_Type ||
- tp == &PyBytes_Type ||
-
- /* other builtins */
- tp == &PySlice_Type ||
- tp == Py_TYPE(Py_None) ||
- tp == Py_TYPE(Py_Ellipsis) ||
- tp == Py_TYPE(Py_NotImplemented) ||
-
- PyModule_Check(tp) ||
- /* sentinel to swallow trailing || */
- false
- );
-}
-
-/*
- * Lookup a special method, following the python approach of looking up
- * on the type object, rather than on the instance itself.
- *
- * Assumes that the special method is a torch-specific one, so does not
- * look at builtin types, nor does it look at a base Tensor.
- *
- * If no special method is found, return NULL, otherwise returns a new
- * reference to the function object
- *
- * In future, could be made more like _Py_LookupSpecial
- */
-
-static py::object PyTorch_LookupSpecial(PyObject *obj, char* name)
-{
- if (THPVariable_CheckExact(obj)) {
- return py::object();
- }
- PyTypeObject *tp = Py_TYPE(obj);
- if (_is_basic_python_type(tp)) {
- return py::object();
- }
- return PyObject_FastGetAttrString((PyObject *)tp, name);
-}
-
-/*
- * Checks if obj has a __torch_function__ implementation
- *
- * Returns true if an implementation is found and false otherwise
- *
- */
-static auto check_has_torch_function(PyObject* obj) -> bool
-{
- if (!torch_function_enabled()) {
- return false;
- }
- py::object method = PyTorch_LookupSpecial(obj, "__torch_function__");
- if(method.ptr() != nullptr && method.ptr() != disabled_torch_function_impl()){
- return true;
- }
- return false;
-}
-
-/*
*
* Handle __torch_function__ overrides if we know that there are overloaded
* arguments. All objects stored in r.overloaded_args must have a
diff --git a/torch/csrc/utils/python_compat.h b/torch/csrc/utils/python_compat.h
index 7e1cb0c..48ce9c1 100644
--- a/torch/csrc/utils/python_compat.h
+++ b/torch/csrc/utils/python_compat.h
@@ -2,6 +2,33 @@
#include <torch/csrc/python_headers.h>
+#if PY_VERSION_HEX < 0x03070000
+// METH_FASTCALL was introduced in Python 3.7, so we wrap _PyCFunctionFast
+// signatures for earlier versions.
+
+template <PyObject* (*f)(PyObject*, PyObject *const *, Py_ssize_t)>
+PyObject* maybe_wrap_fastcall(PyObject *module, PyObject *args) {
+ return f(
+ module,
+
+ // _PyTuple_ITEMS
+ // Because this is only a compat shim for Python 3.6, we don't have
+ // to worry about the representation changing.
+ ((PyTupleObject *)args)->ob_item,
+ PySequence_Fast_GET_SIZE(args)
+ );
+}
+
+#define MAYBE_METH_FASTCALL METH_VARARGS
+#define MAYBE_WRAP_FASTCALL(f) maybe_wrap_fastcall<f>
+
+#else
+
+#define MAYBE_METH_FASTCALL METH_FASTCALL
+#define MAYBE_WRAP_FASTCALL(f) (PyCFunction)(void(*)(void))f
+
+#endif
+
// PyPy 3.6 does not yet have PySlice_Unpack
#if PY_VERSION_HEX < 0x03060100 || defined(PYPY_VERSION)
diff --git a/torch/csrc/utils/python_strings.h b/torch/csrc/utils/python_strings.h
index 55cf0c3..66e5bf1 100644
--- a/torch/csrc/utils/python_strings.h
+++ b/torch/csrc/utils/python_strings.h
@@ -4,6 +4,7 @@
#include <stdexcept>
#include <string>
#include <torch/csrc/utils/object_ptr.h>
+#include <torch/csrc/utils/pybind.h>
// Utilities for handling Python strings. Note that PyString, when defined, is
// the same as PyBytes.
@@ -54,3 +55,49 @@
inline void THPUtils_internStringInPlace(PyObject** obj) {
PyUnicode_InternInPlace(obj);
}
+
+/*
+ * Reference: https://github.com/numpy/numpy/blob/f4c497c768e0646df740b647782df463825bfd27/numpy/core/src/common/get_attr_string.h#L42
+ *
+ * Stripped down version of PyObject_GetAttrString,
+ * avoids lookups for None, tuple, and List objects,
+ * and doesn't create a PyErr since this code ignores it.
+ *
+ * This can be much faster then PyObject_GetAttrString where
+ * exceptions are not used by caller.
+ *
+ * 'obj' is the object to search for attribute.
+ *
+ * 'name' is the attribute to search for.
+ *
+ * Returns a py::object wrapping the return value. If the attribute lookup failed
+ * the value will be NULL.
+ *
+ */
+
+static py::object PyObject_FastGetAttrString(PyObject *obj, char *name)
+{
+ PyTypeObject *tp = Py_TYPE(obj);
+ PyObject *res = (PyObject *)nullptr;
+
+ /* Attribute referenced by (char *)name */
+ if (tp->tp_getattr != nullptr) {
+ res = (*tp->tp_getattr)(obj, name);
+ if (res == nullptr) {
+ PyErr_Clear();
+ }
+ }
+ /* Attribute referenced by (PyObject *)name */
+ else if (tp->tp_getattro != nullptr) {
+ auto w = py::reinterpret_steal<py::object>(
+ THPUtils_internString(name));
+ if (w.ptr() == nullptr) {
+ return py::object();
+ }
+ res = (*tp->tp_getattro)(obj, w.ptr());
+ if (res == nullptr) {
+ PyErr_Clear();
+ }
+ }
+ return py::reinterpret_steal<py::object>(res);
+}
diff --git a/torch/functional.py b/torch/functional.py
index 1442ab5..43fa0a3 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -6,7 +6,9 @@
import torch.nn.functional as F
from torch.types import _size
from ._lowrank import svd_lowrank, pca_lowrank
-from .overrides import has_torch_function, handle_torch_function
+from .overrides import (
+ has_torch_function, has_torch_function_unary, has_torch_function_variadic,
+ handle_torch_function)
from ._jit_internal import boolean_dispatch, List
from ._jit_internal import _overload as overload
from torch._autograd_functions import _LU
@@ -68,7 +70,7 @@
[0, 1, 2]])
"""
if not torch.jit.is_scripting():
- if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):
+ if has_torch_function(tensors):
return handle_torch_function(broadcast_tensors, tensors, *tensors)
return _VF.broadcast_tensors(tensors) # type: ignore
@@ -146,7 +148,7 @@
[8, 9]]))
"""
if not torch.jit.is_scripting():
- if type(tensor) is not Tensor and has_torch_function((tensor,)):
+ if has_torch_function_unary(tensor):
return handle_torch_function(split, (tensor,), tensor, split_size_or_sections,
dim=dim)
# Overwriting reason:
@@ -235,10 +237,9 @@
tensor(2.9802e-08)
"""
if not torch.jit.is_scripting():
- tens_ops = (LU_data, LU_pivots)
- if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+ if has_torch_function_variadic(LU_data, LU_pivots):
return handle_torch_function(
- lu_unpack, tens_ops, LU_data, LU_pivots, unpack_data=unpack_data,
+ lu_unpack, (LU_data, LU_pivots), LU_data, LU_pivots, unpack_data=unpack_data,
unpack_pivots=unpack_pivots)
shape = LU_data.shape
# In generalized LU factorization, the following shape relations hold:
@@ -398,7 +399,7 @@
[ 0.3311, 5.5201, -3.0356]])
"""
if not torch.jit.is_scripting():
- if any(type(t) is not Tensor for t in operands) and has_torch_function(operands):
+ if has_torch_function(operands):
return handle_torch_function(einsum, operands, equation, *operands)
if len(operands) == 1 and isinstance(operands[0], (list, tuple)):
# the old interface of passing the operands as one list argument
@@ -448,7 +449,7 @@
def _meshgrid(*tensors):
if not torch.jit.is_scripting():
- if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):
+ if has_torch_function(tensors):
return handle_torch_function(meshgrid, tensors, *tensors)
if len(tensors) == 1 and isinstance(tensors[0], (list, tuple)):
# the old interface of passing the operands as one list argument
@@ -568,7 +569,7 @@
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(
stft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length,
window=window, center=center, pad_mode=pad_mode, normalized=normalized,
@@ -650,7 +651,7 @@
Tensor: Least squares estimation of the original signal of size (..., signal_length)
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(
istft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length,
window=window, center=center, normalized=normalized, onesided=onesided,
@@ -734,7 +735,7 @@
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(
unique, (input,), input, sorted=sorted, return_inverse=return_inverse,
return_counts=return_counts, dim=dim)
@@ -810,7 +811,7 @@
tensor([2, 2, 1, 2, 1])
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(
unique_consecutive, (input,), input, return_inverse=return_inverse,
return_counts=return_counts, dim=dim)
@@ -823,7 +824,7 @@
# type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return _unique_impl(input, sorted, return_inverse, return_counts, dim)
output, _, counts = _unique_impl(input, sorted, return_inverse, return_counts, dim)
@@ -834,7 +835,7 @@
# type: (Tensor, bool, bool, bool, Optional[int]) -> Tensor
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return _unique_impl(input, sorted, return_inverse, return_counts, dim)
output, _, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim)
@@ -845,7 +846,7 @@
# type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return _unique_impl(input, sorted, return_inverse, return_counts, dim)
output, inverse_indices, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim)
@@ -888,7 +889,7 @@
# type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
output, _, counts = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
@@ -899,7 +900,7 @@
# type: (Tensor, bool, bool, Optional[int]) -> Tensor
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
output, _, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
@@ -910,7 +911,7 @@
# type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
output, inverse_indices, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
@@ -1000,7 +1001,7 @@
[ -0.2850, 4.2573, -3.5997]])
"""
if not torch.jit.is_scripting():
- if (type(a) is not Tensor or type(b) is not Tensor) and has_torch_function((a, b)):
+ if has_torch_function_variadic(a, b):
return handle_torch_function(tensordot, (a, b), a, b, dims=dims)
if isinstance(dims, (list, tuple)) or \
(isinstance(dims, torch.Tensor) and dims.numel() > 1):
@@ -1046,7 +1047,7 @@
[3, 5]])
"""
if not torch.jit.is_scripting():
- if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):
+ if has_torch_function(tensors):
return handle_torch_function(cartesian_prod, tensors, *tensors)
return _VF.cartesian_prod(tensors) # type: ignore
@@ -1080,7 +1081,7 @@
[0, 0, 0, 0, 0, 0, 0, 0, 0, 5],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 6]])
"""
- if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):
+ if has_torch_function(tensors):
return handle_torch_function(block_diag, tensors, *tensors)
return torch._C._VariableFunctions.block_diag(tensors) # type: ignore
@@ -1128,7 +1129,7 @@
[2.2830, 0.3791]])
"""
if not torch.jit.is_scripting():
- if (type(x1) is not Tensor or type(x2) is not Tensor) and has_torch_function((x1, x2)):
+ if has_torch_function_variadic(x1, x2):
return handle_torch_function(
cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode)
if compute_mode == 'use_mm_for_euclid_dist_if_necessary':
@@ -1168,7 +1169,7 @@
(tensor([0.5000]), tensor([1.]))
"""
if not torch.jit.is_scripting():
- if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):
+ if has_torch_function(tensors):
return handle_torch_function(atleast_1d, tensors, *tensors)
if len(tensors) == 1:
tensors = tensors[0]
@@ -1203,7 +1204,7 @@
(tensor([[0.5000]]), tensor([[1.]]))
"""
if not torch.jit.is_scripting():
- if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):
+ if has_torch_function(tensors):
return handle_torch_function(atleast_2d, tensors, *tensors)
if len(tensors) == 1:
tensors = tensors[0]
@@ -1247,7 +1248,7 @@
(tensor([[[0.5000]]]), tensor([[[1.]]]))
"""
if not torch.jit.is_scripting():
- if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):
+ if has_torch_function(tensors):
return handle_torch_function(atleast_3d, tensors, *tensors)
if len(tensors) == 1:
tensors = tensors[0]
@@ -1380,7 +1381,7 @@
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(
norm, (input,), input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype)
@@ -1476,7 +1477,7 @@
.. _`[CLRS]`: https://mitpress.mit.edu/books/introduction-algorithms-third-edition
"""
if not torch.jit.is_scripting():
- if any(type(t) is not Tensor for t in matrices) and has_torch_function(matrices):
+ if has_torch_function(matrices):
return handle_torch_function(chain_matmul, matrices, *matrices)
return _VF.chain_matmul(matrices) # type: ignore
@@ -1596,7 +1597,7 @@
def _lu_with_infos(A, pivot=True, get_infos=False, out=None):
# type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor, Tensor]]) -> Tuple[Tensor, Tensor, Tensor]
if not torch.jit.is_scripting():
- if type(A) is not Tensor and has_torch_function((A,)):
+ if has_torch_function_unary(A):
return handle_torch_function(
lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out)
result = _lu_impl(A, pivot, get_infos, out)
@@ -1612,7 +1613,7 @@
# type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]
# need to check for torch_function here so that we exit if
if not torch.jit.is_scripting():
- if type(A) is not Tensor and has_torch_function((A,)):
+ if has_torch_function_unary(A):
return handle_torch_function(
lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out)
result = _lu_impl(A, pivot, get_infos, out)
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index 162c22ea..2cfc1c2 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -9,12 +9,15 @@
from torch._torch_docs import reproducibility_notes, tf32_notes
from .._jit_internal import boolean_dispatch, _overload
-from ..overrides import has_torch_function, handle_torch_function
+from ..overrides import (
+ has_torch_function, has_torch_function_unary, has_torch_function_variadic,
+ handle_torch_function)
from . import _reduction as _Reduction
from . import grad # noqa: F401
from .modules import utils
from .modules.utils import _single, _pair, _triple, _list_with_default
+
Tensor = torch.Tensor
conv1d = _add_docstr(
@@ -410,7 +413,7 @@
http://arxiv.org/abs/1412.6071
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(
fractional_max_pool2d_with_indices,
(input,),
@@ -438,7 +441,7 @@
):
# type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], Optional[BroadcastingList2[float]], bool, Optional[Tensor]) -> Tensor # noqa
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(
fractional_max_pool2d,
(input,),
@@ -500,7 +503,7 @@
http://arxiv.org/abs/1412.6071
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(
fractional_max_pool3d_with_indices,
(input,),
@@ -532,7 +535,7 @@
):
# type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], Optional[BroadcastingList3[float]], bool, Optional[Tensor]) -> Tensor # noqa
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(
fractional_max_pool3d,
(input,),
@@ -569,7 +572,7 @@
See :class:`~torch.nn.MaxPool1d` for details.
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(
max_pool1d_with_indices,
(input,),
@@ -589,7 +592,7 @@
def _max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False):
# type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], BroadcastingList1[int], bool, bool) -> Tensor # noqa
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(
max_pool1d,
(input,),
@@ -627,7 +630,7 @@
See :class:`~torch.nn.MaxPool2d` for details.
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(
max_pool2d_with_indices,
(input,),
@@ -647,7 +650,7 @@
def _max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False):
# type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], BroadcastingList2[int], bool, bool) -> Tensor # noqa
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(
max_pool2d,
(input,),
@@ -685,7 +688,7 @@
See :class:`~torch.nn.MaxPool3d` for details.
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(
max_pool3d_with_indices,
(input,),
@@ -705,7 +708,7 @@
def _max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False):
# type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], BroadcastingList3[int], bool, bool) -> Tensor # noqa
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(
max_pool3d,
(input,),
@@ -773,7 +776,7 @@
See :class:`~torch.nn.MaxUnpool1d` for details.
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(
max_unpool1d,
(input,),
@@ -805,7 +808,7 @@
See :class:`~torch.nn.MaxUnpool2d` for details.
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(
max_unpool2d,
(input,),
@@ -833,7 +836,7 @@
See :class:`~torch.nn.MaxUnpool3d` for details.
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(
max_unpool3d,
(input,),
@@ -863,7 +866,7 @@
See :class:`~torch.nn.LPPool2d` for details.
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(
lp_pool2d, (input,), input, norm_type, kernel_size, stride=stride, ceil_mode=ceil_mode
)
@@ -885,7 +888,7 @@
See :class:`~torch.nn.LPPool1d` for details.
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(
lp_pool1d, (input,), input, norm_type, kernel_size, stride=stride, ceil_mode=ceil_mode
)
@@ -909,7 +912,7 @@
return_indices: whether to return pooling indices. Default: ``False``
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(
adaptive_max_pool1d_with_indices, (input,), input, output_size, return_indices=return_indices
)
@@ -919,7 +922,7 @@
def _adaptive_max_pool1d(input, output_size, return_indices=False):
# type: (Tensor, BroadcastingList1[int], bool) -> Tensor
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(
adaptive_max_pool1d, (input,), input, output_size, return_indices=return_indices
)
@@ -950,7 +953,7 @@
return_indices: whether to return pooling indices. Default: ``False``
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(
adaptive_max_pool2d_with_indices, (input,), input, output_size, return_indices=return_indices
)
@@ -961,7 +964,7 @@
def _adaptive_max_pool2d(input, output_size, return_indices=False):
# type: (Tensor, BroadcastingList2[int], bool) -> Tensor
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(
adaptive_max_pool2d, (input,), input, output_size, return_indices=return_indices
)
@@ -992,7 +995,7 @@
return_indices: whether to return pooling indices. Default: ``False``
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(
adaptive_max_pool3d_with_indices, (input,), input, output_size, return_indices=return_indices
)
@@ -1003,7 +1006,7 @@
def _adaptive_max_pool3d(input, output_size, return_indices=False):
# type: (Tensor, BroadcastingList3[int], bool) -> Tensor
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(
adaptive_max_pool3d, (input,), input, output_size, return_indices=return_indices
)
@@ -1050,7 +1053,7 @@
double-integer tuple)
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(adaptive_avg_pool2d, (input,), input, output_size)
_output_size = _list_with_default(output_size, input.size())
return torch._C._nn.adaptive_avg_pool2d(input, _output_size)
@@ -1069,7 +1072,7 @@
triple-integer tuple)
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(adaptive_avg_pool3d, (input,), input, output_size)
_output_size = _list_with_default(output_size, input.size())
return torch._C._nn.adaptive_avg_pool3d(input, _output_size)
@@ -1090,7 +1093,7 @@
inplace: If set to ``True``, will do this operation in-place. Default: ``False``
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(dropout, (input,), input, p=p, training=training, inplace=inplace)
if p < 0.0 or p > 1.0:
raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
@@ -1103,7 +1106,7 @@
See :class:`~torch.nn.AlphaDropout` for details.
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(alpha_dropout, (input,), input, p=p, training=training, inplace=inplace)
if p < 0.0 or p > 1.0:
raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
@@ -1126,7 +1129,7 @@
inplace: If set to ``True``, will do this operation in-place. Default: ``False``
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(dropout2d, (input,), input, p=p, training=training, inplace=inplace)
if p < 0.0 or p > 1.0:
raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
@@ -1151,7 +1154,7 @@
# This is 100% the same code as dropout2d. We duplicate this code so that
# stack traces are not confusing.
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(dropout3d, (input,), input, p=p, training=training, inplace=inplace)
if p < 0.0 or p > 1.0:
raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
@@ -1179,7 +1182,7 @@
inplace: If set to ``True``, will do this operation in-place. Default: ``False``
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(
feature_alpha_dropout, (input,), input, p=p, training=training, inplace=inplace
)
@@ -1194,7 +1197,7 @@
See :class:`~torch.nn.Threshold` for more details.
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(_threshold, (input,), input, threshold, value, inplace=inplace)
if inplace:
result = _VF.threshold_(input, threshold, value)
@@ -1225,7 +1228,7 @@
:class:`~torch.nn.ReLU` for more details.
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(relu, (input,), input, inplace=inplace)
if inplace:
result = torch.relu_(input)
@@ -1263,7 +1266,7 @@
dim (int): dimension on which to split the input. Default: -1
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(glu, (input,), input, dim=dim)
if input.dim() == 0:
raise RuntimeError("glu does not support scalars because halving size must be even")
@@ -1278,7 +1281,7 @@
details.
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(hardtanh, (input,), input, min_val=min_val, max_val=max_val, inplace=inplace)
if inplace:
result = torch._C._nn.hardtanh_(input, min_val, max_val)
@@ -1305,7 +1308,7 @@
See :class:`~torch.nn.ReLU6` for more details.
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(relu6, (input,), input, inplace=inplace)
return hardtanh(input, 0.0, 6.0, inplace)
@@ -1317,7 +1320,7 @@
See :class:`~torch.nn.ELU` for more details.
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(elu, (input,), input, alpha=alpha, inplace=inplace)
if inplace:
result = torch._C._nn.elu_(input, alpha)
@@ -1347,7 +1350,7 @@
See :class:`~torch.nn.SELU` for more details.
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(selu, (input,), input, inplace=inplace)
if inplace:
result = torch.selu_(input)
@@ -1375,7 +1378,7 @@
See :class:`~torch.nn.CELU` for more details.
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(celu, (input,), input, alpha=alpha, inplace=inplace)
if inplace:
result = torch.celu_(input, alpha)
@@ -1404,7 +1407,7 @@
See :class:`~torch.nn.LeakyReLU` for more details.
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(leaky_relu, (input,), input, negative_slope=negative_slope, inplace=inplace)
if inplace:
result = torch._C._nn.leaky_relu_(input, negative_slope)
@@ -1433,7 +1436,7 @@
See :class:`~torch.nn.PReLU` for more details.
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(prelu, (input,), input, weight)
return torch.prelu(input, weight)
@@ -1448,7 +1451,7 @@
See :class:`~torch.nn.RReLU` for more details.
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(
rrelu, (input,), input, lower=lower, upper=upper, training=training, inplace=inplace
)
@@ -1491,7 +1494,7 @@
See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_.
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(gelu, (input,), input)
return torch._C._nn.gelu(input)
@@ -1505,7 +1508,7 @@
See :class:`~torch.nn.Hardshrink` for more details.
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(hardshrink, (input,), input, lambd=lambd)
return torch.hardshrink(input, lambd)
@@ -1518,7 +1521,7 @@
See :class:`~torch.nn.Tanhshrink` for more details.
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(tanhshrink, (input,), input)
return input - input.tanh()
@@ -1531,7 +1534,7 @@
See :class:`~torch.nn.Softsign` for more details.
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(softsign, (input,), input)
return input / (input.abs() + 1)
@@ -1580,7 +1583,7 @@
is performed. This is useful for preventing data type overflows. Default: None.
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(softmin, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype)
if dim is None:
dim = _get_softmax_dim("softmin", input.dim(), _stacklevel)
@@ -1617,7 +1620,7 @@
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype)
if dim is None:
dim = _get_softmax_dim("softmax", input.dim(), _stacklevel)
@@ -1669,7 +1672,7 @@
https://arxiv.org/abs/1611.01144
"""
if not torch.jit.is_scripting():
- if type(logits) is not Tensor and has_torch_function((logits,)):
+ if has_torch_function_unary(logits):
return handle_torch_function(gumbel_softmax, (logits,), logits, tau=tau, hard=hard, eps=eps, dim=dim)
if eps != 1e-10:
warnings.warn("`eps` parameter is deprecated and has no effect.")
@@ -1708,7 +1711,7 @@
is performed. This is useful for preventing data type overflows. Default: None.
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(log_softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype)
if dim is None:
dim = _get_softmax_dim("log_softmax", input.dim(), _stacklevel)
@@ -1772,7 +1775,7 @@
See :class:`~torch.nn.Hardsigmoid` for more details.
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(hardsigmoid, (input,), input, inplace=inplace)
if inplace:
return torch._C._nn.hardsigmoid_(input)
@@ -1793,10 +1796,9 @@
- Bias: :math:`(out\_features)`
- Output: :math:`(N, *, out\_features)`
"""
- tens_ops = (input, weight)
if not torch.jit.is_scripting():
- if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
- return handle_torch_function(linear, tens_ops, input, weight, bias=bias)
+ if has_torch_function_variadic(input, weight):
+ return handle_torch_function(linear, (input, weight), input, weight, bias=bias)
if input.dim() == 2 and bias is not None:
# fused op is marginally faster
ret = torch.addmm(bias, input, weight.t())
@@ -1845,7 +1847,7 @@
See :class:`~torch.nn.SiLU` for more details.
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(silu, (input,), input, inplace=inplace)
if inplace:
return torch._C._nn.silu_(input)
@@ -1870,7 +1872,7 @@
https://arxiv.org/abs/1905.02244
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(hardswish, (input,), input, inplace=inplace)
if inplace:
return torch._C._nn.hardswish_(input)
@@ -2058,11 +2060,10 @@
"""
if not torch.jit.is_scripting():
- tens_ops = (input, weight)
- if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+ if has_torch_function_variadic(input, weight):
return handle_torch_function(
embedding_bag,
- tens_ops,
+ (input, weight),
input,
weight,
offsets=offsets,
@@ -2188,7 +2189,7 @@
:class:`~torch.nn.BatchNorm3d` for details.
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(
batch_norm,
(input,),
@@ -2227,7 +2228,7 @@
:class:`~torch.nn.InstanceNorm3d` for details.
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(
instance_norm,
(input,),
@@ -2258,7 +2259,7 @@
See :class:`~torch.nn.LayerNorm` for details.
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(
layer_norm, (input,), input, normalized_shape, weight=weight, bias=bias, eps=eps
)
@@ -2273,7 +2274,7 @@
See :class:`~torch.nn.GroupNorm` for details.
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(group_norm, (input,), input, num_groups, weight=weight, bias=bias, eps=eps)
_verify_batch_size([input.size(0) * input.size(1) // num_groups, num_groups] + list(input.size()[2:]))
return torch.group_norm(input, num_groups, weight, bias, eps, torch.backends.cudnn.enabled)
@@ -2287,7 +2288,7 @@
See :class:`~torch.nn.LocalResponseNorm` for details.
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(local_response_norm, (input,), input, size, alpha=alpha, beta=beta, k=k)
dim = input.dim()
if dim < 3:
@@ -2425,11 +2426,10 @@
>>> output.backward()
"""
if not torch.jit.is_scripting():
- tens_ops = (input, target)
- if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+ if has_torch_function_variadic(input, target):
return handle_torch_function(
nll_loss,
- tens_ops,
+ (input, target),
input,
target,
weight=weight,
@@ -2522,11 +2522,10 @@
"""
if not torch.jit.is_scripting():
- tens_ops = (input, target)
- if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+ if has_torch_function_variadic(input, target):
return handle_torch_function(
poisson_nll_loss,
- tens_ops,
+ (input, target),
input,
target,
log_input=log_input,
@@ -2593,11 +2592,10 @@
In the next major release, ``'mean'`` will be changed to be the same as 'batchmean'.
"""
if not torch.jit.is_scripting():
- tens_ops = (input, target)
- if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+ if has_torch_function_variadic(input, target):
return handle_torch_function(
kl_div,
- tens_ops,
+ (input, target),
input,
target,
size_average=size_average,
@@ -2679,11 +2677,10 @@
>>> loss.backward()
"""
if not torch.jit.is_scripting():
- tens_ops = (input, target)
- if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+ if has_torch_function_variadic(input, target):
return handle_torch_function(
cross_entropy,
- tens_ops,
+ (input, target),
input,
target,
weight=weight,
@@ -2739,11 +2736,10 @@
>>> loss.backward()
"""
if not torch.jit.is_scripting():
- tens_ops = (input, target)
- if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+ if has_torch_function_variadic(input, target):
return handle_torch_function(
binary_cross_entropy,
- tens_ops,
+ (input, target),
input,
target,
weight=weight,
@@ -2813,11 +2809,10 @@
>>> loss.backward()
"""
if not torch.jit.is_scripting():
- tens_ops = (input, target)
- if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+ if has_torch_function_variadic(input, target):
return handle_torch_function(
binary_cross_entropy_with_logits,
- tens_ops,
+ (input, target),
input,
target,
weight=weight,
@@ -2851,11 +2846,10 @@
See :class:`~torch.nn.SmoothL1Loss` for details.
"""
if not torch.jit.is_scripting():
- tens_ops = (input, target)
- if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+ if has_torch_function_variadic(input, target):
return handle_torch_function(
smooth_l1_loss,
- tens_ops,
+ (input, target),
input,
target,
size_average=size_average,
@@ -2891,10 +2885,9 @@
See :class:`~torch.nn.L1Loss` for details.
"""
if not torch.jit.is_scripting():
- tens_ops = (input, target)
- if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+ if has_torch_function_variadic(input, target):
return handle_torch_function(
- l1_loss, tens_ops, input, target, size_average=size_average, reduce=reduce, reduction=reduction
+ l1_loss, (input, target), input, target, size_average=size_average, reduce=reduce, reduction=reduction
)
if not (target.size() == input.size()):
warnings.warn(
@@ -2924,10 +2917,9 @@
See :class:`~torch.nn.MSELoss` for details.
"""
if not torch.jit.is_scripting():
- tens_ops = (input, target)
- if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+ if has_torch_function_variadic(input, target):
return handle_torch_function(
- mse_loss, tens_ops, input, target, size_average=size_average, reduce=reduce, reduction=reduction
+ mse_loss, (input, target), input, target, size_average=size_average, reduce=reduce, reduction=reduction
)
if not (target.size() == input.size()):
warnings.warn(
@@ -2957,11 +2949,10 @@
See :class:`~torch.nn.MarginRankingLoss` for details.
""" # noqa
if not torch.jit.is_scripting():
- tens_ops = (input1, input2, target)
- if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+ if has_torch_function_variadic(input1, input2, target):
return handle_torch_function(
margin_ranking_loss,
- tens_ops,
+ (input1, input2, target),
input1,
input2,
target,
@@ -2997,11 +2988,10 @@
See :class:`~torch.nn.HingeEmbeddingLoss` for details.
""" # noqa
if not torch.jit.is_scripting():
- tens_ops = (input, target)
- if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+ if has_torch_function_variadic(input, target):
return handle_torch_function(
hinge_embedding_loss,
- tens_ops,
+ (input, target),
input,
target,
margin=margin,
@@ -3028,11 +3018,10 @@
See :class:`~torch.nn.MultiLabelMarginLoss` for details.
"""
if not torch.jit.is_scripting():
- tens_ops = (input, target)
- if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+ if has_torch_function_variadic(input, target):
return handle_torch_function(
multilabel_margin_loss,
- tens_ops,
+ (input, target),
input,
target,
size_average=size_average,
@@ -3058,10 +3047,9 @@
See :class:`~torch.nn.SoftMarginLoss` for details.
"""
if not torch.jit.is_scripting():
- tens_ops = (input, target)
- if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+ if has_torch_function_variadic(input, target):
return handle_torch_function(
- soft_margin_loss, tens_ops, input, target, size_average=size_average, reduce=reduce, reduction=reduction
+ soft_margin_loss, (input, target), input, target, size_average=size_average, reduce=reduce, reduction=reduction
)
if size_average is not None or reduce is not None:
reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
@@ -3083,11 +3071,10 @@
See :class:`~torch.nn.MultiLabelSoftMarginLoss` for details.
"""
if not torch.jit.is_scripting():
- tens_ops = (input, target)
- if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+ if has_torch_function_variadic(input, target):
return handle_torch_function(
multilabel_soft_margin_loss,
- tens_ops,
+ (input, target),
input,
target,
weight=weight,
@@ -3131,11 +3118,10 @@
See :class:`~torch.nn.CosineEmbeddingLoss` for details.
""" # noqa
if not torch.jit.is_scripting():
- tens_ops = (input1, input2, target)
- if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+ if has_torch_function_variadic(input1, input2, target):
return handle_torch_function(
cosine_embedding_loss,
- tens_ops,
+ (input1, input2, target),
input1,
input2,
target,
@@ -3167,11 +3153,10 @@
See :class:`~torch.nn.MultiMarginLoss` for details.
"""
if not torch.jit.is_scripting():
- tens_ops = (input, target)
- if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+ if has_torch_function_variadic(input, target):
return handle_torch_function(
multi_margin_loss,
- tens_ops,
+ (input, target),
input,
target,
p=p,
@@ -3460,7 +3445,7 @@
{backward_reproducibility_note}
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(
interpolate,
(input,),
@@ -3824,10 +3809,9 @@
.. _`OpenCV`: https://github.com/opencv/opencv/blob/f345ed564a06178670750bad59526cfa4033be55/modules/imgproc/src/resize.cpp#L908
"""
if not torch.jit.is_scripting():
- tens_ops = (input, grid)
- if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+ if has_torch_function_variadic(input, grid):
return handle_torch_function(
- grid_sample, tens_ops, input, grid, mode=mode, padding_mode=padding_mode, align_corners=align_corners
+ grid_sample, (input, grid), input, grid, mode=mode, padding_mode=padding_mode, align_corners=align_corners
)
if mode != "bilinear" and mode != "nearest" and mode != "bicubic":
raise ValueError(
@@ -3916,7 +3900,7 @@
(the center of the input image).
"""
if not torch.jit.is_scripting():
- if type(theta) is not Tensor and has_torch_function((theta,)):
+ if has_torch_function_unary(theta):
return handle_torch_function(affine_grid, (theta,), theta, size, align_corners=align_corners)
if align_corners is None:
warnings.warn(
@@ -4025,7 +4009,7 @@
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(_pad, (input,), input, pad, mode=mode, value=value)
assert len(pad) % 2 == 0, "Padding length must be divisible by 2"
assert len(pad) // 2 <= input.dim(), "Padding length too large"
@@ -4208,11 +4192,10 @@
See :class:`~torch.nn.TripletMarginLoss` for details
"""
if not torch.jit.is_scripting():
- tens_ops = (anchor, positive, negative)
- if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+ if has_torch_function_variadic(anchor, positive, negative):
return handle_torch_function(
triplet_margin_loss,
- tens_ops,
+ (anchor, positive, negative),
anchor,
positive,
negative,
@@ -4250,11 +4233,10 @@
"functions requiring Callables cannot be scripted."
)
- tens_ops = (anchor, positive, negative)
- if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+ if has_torch_function_variadic(anchor, positive, negative):
return handle_torch_function(
triplet_margin_with_distance_loss,
- tens_ops,
+ (anchor, positive, negative),
anchor,
positive,
negative,
@@ -4304,7 +4286,7 @@
operation won't be differentiable.
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(normalize, (input,), input, p=p, dim=dim, eps=eps, out=out)
if out is None:
denom = input.norm(p, dim, keepdim=True).clamp_min(eps).expand_as(input)
@@ -4337,7 +4319,7 @@
See :class:`torch.nn.Unfold` for details
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(
unfold, (input,), input, kernel_size, dilation=dilation, padding=padding, stride=stride
)
@@ -4365,7 +4347,7 @@
See :class:`torch.nn.Fold` for details
"""
if not torch.jit.is_scripting():
- if type(input) is not Tensor and has_torch_function((input,)):
+ if has_torch_function_unary(input):
return handle_torch_function(
fold, (input,), input, output_size, kernel_size, dilation=dilation, padding=padding, stride=stride
)
@@ -4633,7 +4615,7 @@
"""
if not torch.jit.is_scripting():
tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
- if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+ if has_torch_function(tens_ops):
return handle_torch_function(
multi_head_attention_forward,
tens_ops,
diff --git a/torch/overrides.py b/torch/overrides.py
index 889bf5b..96c7f82 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -28,7 +28,9 @@
from typing import Dict, Set, List, Any, Callable, Iterable
import torch
-from torch._C import _is_torch_function_enabled, _disabled_torch_function_impl
+from torch._C import (
+ _has_torch_function, _has_torch_function_unary,
+ _has_torch_function_variadic, _add_docstr)
__all__ = [
"get_ignored_functions",
@@ -167,6 +169,8 @@
torch.nn.functional.upsample_bilinear,
torch.nn.functional.upsample_nearest,
torch.nn.functional.has_torch_function,
+ torch.nn.functional.has_torch_function_unary,
+ torch.nn.functional.has_torch_function_variadic,
torch.nn.functional.handle_torch_function,
torch.nn.functional.sigmoid,
torch.nn.functional.hardsigmoid,
@@ -1193,33 +1197,52 @@
'__torch_function__: {}'
.format(func_name, [type(arg) for arg in overloaded_args]))
-def has_torch_function(relevant_args: Iterable[Any]) -> bool:
- """Check for __torch_function__ implementations in the elements of an iterable.
-
- Considers exact ``Tensor`` s non-dispatchable.
-
+has_torch_function = _add_docstr(
+ _has_torch_function,
+ r"""Check for __torch_function__ implementations in the elements of an iterable.
+ Considers exact ``Tensor`` s and ``Parameter`` s non-dispatchable.
Arguments
---------
relevant_args : iterable
Iterable or aguments to check for __torch_function__ methods.
-
Returns
-------
bool
True if any of the elements of relevant_args have __torch_function__
implementations, False otherwise.
-
See Also
________
torch.is_tensor_like
Checks if something is a Tensor-like, including an exact ``Tensor``.
"""
- return _is_torch_function_enabled() and any(
- type(a) is not torch.Tensor and
- getattr(a, '__torch_function__', _disabled_torch_function_impl)
- is not _disabled_torch_function_impl
- for a in relevant_args
- )
+)
+
+has_torch_function_unary = _add_docstr(
+ _has_torch_function_unary,
+ r"""Special case of `has_torch_function` for single inputs.
+ Instead of:
+ `has_torch_function((t,))`
+ call:
+ `has_torch_function_unary(t)`
+ which skips unnecessary packing and unpacking work.
+ """
+)
+
+has_torch_function_variadic = _add_docstr(
+ _has_torch_function_variadic,
+ r"""Special case of `has_torch_function` that skips tuple creation.
+
+ This uses the METH_FASTCALL protocol introduced in Python 3.7; for 3.6
+ and before it has roughly equivilent performance compared to
+ `has_torch_function`.
+
+ Instead of:
+ `has_torch_function((a, b))`
+ call:
+ `has_torch_function_variadic(a, b)`
+ which skips unnecessary packing and unpacking work.
+ """
+)
@functools.lru_cache(None)
def get_overridable_functions() -> Dict[Any, List[Callable]]:
diff --git a/torch/tensor.py b/torch/tensor.py
index bd8f347..eedffc7 100644
--- a/torch/tensor.py
+++ b/torch/tensor.py
@@ -10,7 +10,9 @@
from torch._namedtensor_internals import (
update_names, check_serializing_named_tensor, resolve_ellipsis,
unzip_namedshape, single_ellipsis_index, is_ellipsis)
-from torch.overrides import has_torch_function, handle_torch_function
+from torch.overrides import (
+ has_torch_function, has_torch_function_unary, has_torch_function_variadic,
+ handle_torch_function)
import torch.utils.hooks as hooks
@@ -21,7 +23,7 @@
@functools.wraps(f, assigned=assigned)
def wrapped(*args, **kwargs):
- if not all(type(t) is Tensor for t in args) and has_torch_function(args):
+ if has_torch_function(args):
return handle_torch_function(wrapped, args, *args, **kwargs)
try:
return f(*args, **kwargs)
@@ -39,9 +41,8 @@
# otherwise, it will not show up in autocomplete.
class Tensor(torch._C._TensorBase):
def __deepcopy__(self, memo):
- relevant_args = (self,)
- if type(self) is not Tensor and has_torch_function(relevant_args):
- return handle_torch_function(Tensor.__deepcopy__, relevant_args, self, memo)
+ if has_torch_function_unary(self):
+ return handle_torch_function(Tensor.__deepcopy__, (self,), self, memo)
if not self.is_leaf:
raise RuntimeError("Only Tensors created explicitly by the user "
"(graph leaves) support the deepcopy protocol at the moment")
@@ -80,9 +81,8 @@
return new_tensor
def __reduce_ex__(self, proto):
- relevant_args = (self,)
- if type(self) is not Tensor and has_torch_function(relevant_args):
- return handle_torch_function(Tensor.__reduce_ex__, relevant_args, self, proto)
+ if has_torch_function_unary(self):
+ return handle_torch_function(Tensor.__reduce_ex__, (self,), self, proto)
check_serializing_named_tensor(self)
# See Note [Don't serialize hooks]
torch.utils.hooks.warn_if_has_hooks(self)
@@ -148,9 +148,8 @@
return (torch._utils._rebuild_tensor_v2, args)
def __setstate__(self, state):
- relevant_args = (self,)
- if type(self) is not Tensor and has_torch_function(relevant_args):
- return handle_torch_function(Tensor.__setstate__, relevant_args, self, state)
+ if has_torch_function_unary(self):
+ return handle_torch_function(Tensor.__setstate__, (self,), self, state)
# Warning: this method is NOT called when you torch.load() a tensor;
# that is managed by _rebuild_tensor_v2
if not self.is_leaf:
@@ -168,9 +167,8 @@
self.requires_grad, _, self._backward_hooks = state
def __repr__(self):
- relevant_args = (self,)
- if type(self) is not Tensor and has_torch_function(relevant_args):
- return handle_torch_function(Tensor.__repr__, relevant_args, self)
+ if has_torch_function_unary(self):
+ return handle_torch_function(Tensor.__repr__, (self,), self)
# All strings are unicode in Python 3.
return torch._tensor_str._str(self)
@@ -215,11 +213,10 @@
used to compute the attr::tensors. All the provided inputs must be leaf
Tensors.
"""
- relevant_args = (self,)
- if type(self) is not Tensor and has_torch_function(relevant_args):
+ if has_torch_function_unary(self):
return handle_torch_function(
Tensor.backward,
- relevant_args,
+ (self,),
self,
gradient=gradient,
retain_graph=retain_graph,
@@ -256,9 +253,8 @@
>>> h.remove() # removes the hook
"""
- relevant_args = (self,)
- if type(self) is not Tensor and has_torch_function(relevant_args):
- return handle_torch_function(Tensor.register_hook, relevant_args, self, hook)
+ if has_torch_function_unary(self):
+ return handle_torch_function(Tensor.register_hook, (self,), self, hook)
if not self.requires_grad:
raise RuntimeError("cannot register a hook on a tensor that "
"doesn't require gradient")
@@ -324,9 +320,8 @@
def retain_grad(self):
r"""Enables .grad attribute for non-leaf Tensors."""
- relevant_args = (self,)
- if type(self) is not Tensor and has_torch_function(relevant_args):
- return handle_torch_function(Tensor.retain_grad, relevant_args, self)
+ if has_torch_function_unary(self):
+ return handle_torch_function(Tensor.retain_grad, (self,), self)
if not self.requires_grad:
raise RuntimeError("can't retain_grad on Tensor that has requires_grad=False")
if self.is_leaf: # no-op for leaves
@@ -355,9 +350,8 @@
This is always ``True`` for CUDA tensors.
"""
- relevant_args = (self,)
- if type(self) is not Tensor and has_torch_function(relevant_args):
- return handle_torch_function(Tensor.is_shared, relevant_args, self)
+ if has_torch_function_unary(self):
+ return handle_torch_function(Tensor.is_shared, (self,), self)
return self.storage().is_shared()
def share_memory_(self):
@@ -366,17 +360,15 @@
This is a no-op if the underlying storage is already in shared memory
and for CUDA tensors. Tensors in shared memory cannot be resized.
"""
- relevant_args = (self,)
- if type(self) is not Tensor and has_torch_function(relevant_args):
- return handle_torch_function(Tensor.share_memory_, relevant_args, self)
+ if has_torch_function_unary(self):
+ return handle_torch_function(Tensor.share_memory_, (self,), self)
self.storage().share_memory_()
return self
def __reversed__(self):
r"""Reverses the tensor along dimension 0."""
- relevant_args = (self,)
- if type(self) is not Tensor and has_torch_function(relevant_args):
- return handle_torch_function(Tensor.__reversed__, relevant_args, self)
+ if has_torch_function_unary(self):
+ return handle_torch_function(Tensor.__reversed__, (self,), self)
if self.dim() == 0:
return self
else:
@@ -384,17 +376,15 @@
def norm(self, p="fro", dim=None, keepdim=False, dtype=None):
r"""See :func:`torch.norm`"""
- relevant_args = (self,)
- if type(self) is not Tensor and has_torch_function(relevant_args):
- return handle_torch_function(Tensor.norm, relevant_args, self, p=p, dim=dim, keepdim=keepdim, dtype=dtype)
+ if has_torch_function_unary(self):
+ return handle_torch_function(Tensor.norm, (self,), self, p=p, dim=dim, keepdim=keepdim, dtype=dtype)
return torch.norm(self, p, dim, keepdim, dtype=dtype)
def lu(self, pivot=True, get_infos=False):
r"""See :func:`torch.lu`"""
# If get_infos is True, then we don't need to check for errors and vice versa
- relevant_args = (self,)
- if type(self) is not Tensor and has_torch_function(relevant_args):
- return handle_torch_function(Tensor.lu, relevant_args, self, pivot=pivot, get_infos=get_infos)
+ if has_torch_function_unary(self):
+ return handle_torch_function(Tensor.lu, (self,), self, pivot=pivot, get_infos=get_infos)
if not torch._jit_internal.is_scripting():
if self.requires_grad:
@@ -434,10 +424,9 @@
This function changed signature at version 0.4.1. Calling with
the previous signature may cause error or return incorrect result.
"""
- relevant_args = (self,)
- if type(self) is not Tensor and has_torch_function(relevant_args):
+ if has_torch_function_unary(self):
return handle_torch_function(
- Tensor.stft, relevant_args, self, n_fft, hop_length=hop_length,
+ Tensor.stft, (self,), self, n_fft, hop_length=hop_length,
win_length=win_length, window=window, center=center, pad_mode=pad_mode, normalized=normalized,
onesided=onesided, return_complex=return_complex
)
@@ -450,10 +439,9 @@
onesided: Optional[bool] = None, length: Optional[int] = None,
return_complex: bool = False):
r"""See :func:`torch.istft`"""
- relevant_args = (self,)
- if type(self) is not Tensor and has_torch_function(relevant_args):
+ if has_torch_function_unary(self):
return handle_torch_function(
- Tensor.istft, relevant_args, self, n_fft, hop_length=hop_length, win_length=win_length,
+ Tensor.istft, (self,), self, n_fft, hop_length=hop_length, win_length=win_length,
window=window, center=center, normalized=normalized, onesided=onesided, length=length,
return_complex=return_complex
)
@@ -461,17 +449,15 @@
normalized, onesided, length, return_complex=return_complex)
def resize(self, *sizes):
- relevant_args = (self,)
- if type(self) is not Tensor and has_torch_function(relevant_args):
- return handle_torch_function(Tensor.resize, relevant_args, self, *sizes)
+ if has_torch_function_unary(self):
+ return handle_torch_function(Tensor.resize, (self,), self, *sizes)
warnings.warn("non-inplace resize is deprecated")
from torch.autograd._functions import Resize
return Resize.apply(self, sizes)
def resize_as(self, tensor):
- relevant_args = (self, tensor)
- if type(self) is not Tensor and type(tensor) is not Tensor and has_torch_function(relevant_args):
- return handle_torch_function(Tensor.resize_as, relevant_args, self, tensor)
+ if has_torch_function_variadic(self, tensor):
+ return handle_torch_function(Tensor.resize_as, (self, tensor), self, tensor)
warnings.warn("non-inplace resize_as is deprecated")
from torch.autograd._functions import Resize
return Resize.apply(self, tensor.size())
@@ -479,9 +465,8 @@
def split(self, split_size, dim=0):
r"""See :func:`torch.split`
"""
- relevant_args = (self,)
- if type(self) is not Tensor and has_torch_function(relevant_args):
- return handle_torch_function(Tensor.split, relevant_args, self, split_size, dim=dim)
+ if has_torch_function_unary(self):
+ return handle_torch_function(Tensor.split, (self,), self, split_size, dim=dim)
if isinstance(split_size, int):
return super(Tensor, self).split(split_size, dim)
elif isinstance(split_size, Tensor):
@@ -498,10 +483,9 @@
See :func:`torch.unique`
"""
- relevant_args = (self,)
- if type(self) is not Tensor and has_torch_function(relevant_args):
+ if has_torch_function_unary(self):
return handle_torch_function(
- Tensor.unique, relevant_args, self, sorted=sorted, return_inverse=return_inverse,
+ Tensor.unique, (self,), self, sorted=sorted, return_inverse=return_inverse,
return_counts=return_counts, dim=dim
)
return torch.unique(self, sorted=sorted, return_inverse=return_inverse, return_counts=return_counts, dim=dim)
@@ -511,24 +495,21 @@
See :func:`torch.unique_consecutive`
"""
- relevant_args = (self,)
- if type(self) is not Tensor and has_torch_function(relevant_args):
+ if has_torch_function_unary(self):
return handle_torch_function(
- Tensor.unique_consecutive, relevant_args, self, return_inverse=return_inverse,
+ Tensor.unique_consecutive, (self,), self, return_inverse=return_inverse,
return_counts=return_counts, dim=dim
)
return torch.unique_consecutive(self, return_inverse=return_inverse, return_counts=return_counts, dim=dim)
def __rsub__(self, other):
- relevant_args = (self, other)
- if type(self) is not Tensor and type(other) is not Tensor and has_torch_function(relevant_args):
- return handle_torch_function(Tensor.__rsub__, relevant_args, self, other)
+ if has_torch_function_variadic(self, other):
+ return handle_torch_function(Tensor.__rsub__, (self, other), self, other)
return _C._VariableFunctions.rsub(self, other)
def __rdiv__(self, other):
- relevant_args = (self, other)
- if type(self) is not Tensor and type(other) is not Tensor and has_torch_function(relevant_args):
- return handle_torch_function(Tensor.__rdiv__, relevant_args, self, other)
+ if has_torch_function_variadic(self, other):
+ return handle_torch_function(Tensor.__rdiv__, (self, other), self, other)
return self.reciprocal() * other
__rtruediv__ = __rdiv__
@@ -537,17 +518,15 @@
__pow__ = _C._TensorBase.pow
def __format__(self, format_spec):
- relevant_args = (self,)
- if type(self) is not Tensor and has_torch_function(relevant_args):
- return handle_torch_function(Tensor.__format__, relevant_args, self, format_spec)
+ if has_torch_function_unary(self):
+ return handle_torch_function(Tensor.__format__, (self,), self, format_spec)
if self.dim() == 0:
return self.item().__format__(format_spec)
return object.__format__(self, format_spec)
def __ipow__(self, other): # type: ignore[misc]
- relevant_args = (self, other)
- if type(self) is not Tensor and type(other) is not Tensor and has_torch_function(relevant_args):
- return handle_torch_function(Tensor.__ipow__, relevant_args, self, other)
+ if has_torch_function_variadic(self, other):
+ return handle_torch_function(Tensor.__ipow__, (self, other), self, other)
return NotImplemented
@_wrap_type_error_to_not_implemented
@@ -567,9 +546,8 @@
__abs__ = _C._TensorBase.abs
def __len__(self):
- relevant_args = (self,)
- if type(self) is not Tensor and has_torch_function(relevant_args):
- return handle_torch_function(Tensor.__len__, relevant_args, self)
+ if has_torch_function_unary(self):
+ return handle_torch_function(Tensor.__len__, (self,), self)
if self.dim() == 0:
raise TypeError("len() of a 0-d tensor")
return self.shape[0]
@@ -581,9 +559,8 @@
# (e.g., if you zip(*hiddens), the eager map will force all the
# indexes of hiddens[0] before hiddens[1], while the generator
# map will interleave them.)
- relevant_args = (self,)
- if type(self) is not Tensor and has_torch_function(relevant_args):
- return handle_torch_function(Tensor.__iter__, relevant_args, self)
+ if has_torch_function_unary(self):
+ return handle_torch_function(Tensor.__iter__, (self,), self)
if self.dim() == 0:
raise TypeError('iteration over a 0-d tensor')
if torch._C._get_tracing_state():
@@ -594,15 +571,13 @@
return iter(self.unbind(0))
def __hash__(self):
- relevant_args = (self,)
- if type(self) is not Tensor and has_torch_function(relevant_args):
- return handle_torch_function(Tensor.__hash__, relevant_args, self)
+ if has_torch_function_unary(self):
+ return handle_torch_function(Tensor.__hash__, (self,), self)
return id(self)
def __dir__(self):
- relevant_args = (self,)
- if type(self) is not Tensor and has_torch_function(relevant_args):
- return handle_torch_function(Tensor.__dir__, relevant_args, self)
+ if has_torch_function_unary(self):
+ return handle_torch_function(Tensor.__dir__, (self,), self)
if self.is_quantized:
warnings.warn('Only a small subset of methods are supported for quantized tensors.')
tensor_methods = dir(self.__class__)
@@ -620,9 +595,8 @@
__array_priority__ = 1000 # prefer Tensor ops over numpy ones
def __array__(self, dtype=None):
- relevant_args = (self,)
- if type(self) is not Tensor and has_torch_function(relevant_args):
- return handle_torch_function(Tensor.__array__, relevant_args, self, dtype=dtype)
+ if has_torch_function_unary(self):
+ return handle_torch_function(Tensor.__array__, (self,), self, dtype=dtype)
if dtype is None:
return self.numpy()
else:
@@ -631,9 +605,8 @@
# Wrap Numpy array again in a suitable tensor when done, to support e.g.
# `numpy.sin(tensor) -> tensor` or `numpy.greater(tensor, 0) -> ByteTensor`
def __array_wrap__(self, array):
- relevant_args = (self,)
- if type(self) is not Tensor and has_torch_function(relevant_args):
- return handle_torch_function(Tensor.__array_wrap__, relevant_args, self, array=array)
+ if has_torch_function_unary(self):
+ return handle_torch_function(Tensor.__array_wrap__, (self,), self, array=array)
if array.dtype == bool:
# Workaround, torch has no built-in bool tensor
array = array.astype('uint8')
@@ -646,9 +619,8 @@
element (Tensor or scalar): element to be checked
for presence in current tensor"
"""
- relevant_args = (self,)
- if type(self) is not Tensor and has_torch_function(relevant_args):
- return handle_torch_function(Tensor.__contains__, relevant_args, self, element)
+ if has_torch_function_unary(self):
+ return handle_torch_function(Tensor.__contains__, (self,), self, element)
if isinstance(element, (torch.Tensor, Number)):
# type hint doesn't understand the __contains__ result array
return (element == self).any().item() # type: ignore[union-attr]
@@ -665,10 +637,9 @@
See:
https://numba.pydata.org/numba-doc/latest/cuda/cuda_array_interface.html
"""
- relevant_args = (self,)
- if type(self) is not Tensor and has_torch_function(relevant_args):
+ if has_torch_function_unary(self):
# TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
- return handle_torch_function(Tensor.__cuda_array_interface__.__get__, relevant_args, self) # type: ignore[attr-defined]
+ return handle_torch_function(Tensor.__cuda_array_interface__.__get__, (self,), self) # type: ignore[attr-defined]
# raise AttributeError for unsupported tensors, so that
# hasattr(cpu_tensor, "__cuda_array_interface__") is False.
@@ -761,9 +732,8 @@
The named tensor API is experimental and subject to change.
"""
- relevant_args = (self,)
- if type(self) is not Tensor and has_torch_function(relevant_args):
- return handle_torch_function(Tensor.refine_names, relevant_args, self, *names)
+ if has_torch_function_unary(self):
+ return handle_torch_function(Tensor.refine_names, (self,), self, *names)
names = resolve_ellipsis(names, self.names, 'refine_names')
return super(Tensor, self).refine_names(names)
@@ -803,9 +773,8 @@
The named tensor API is experimental and subject to change.
"""
- relevant_args = (self,)
- if type(self) is not Tensor and has_torch_function(relevant_args):
- return handle_torch_function(Tensor.align_to, relevant_args, self, *names)
+ if has_torch_function_unary(self):
+ return handle_torch_function(Tensor.align_to, (self,), self, *names)
ellipsis_idx = single_ellipsis_index(names, 'align_to')
if ellipsis_idx is None:
return super(Tensor, self).align_to(names)
@@ -839,9 +808,8 @@
.. warning::
The named tensor API is experimental and subject to change.
"""
- relevant_args = (self,)
- if type(self) is not Tensor and has_torch_function(relevant_args):
- return handle_torch_function(Tensor.unflatten, relevant_args, self, dim, sizes)
+ if has_torch_function_unary(self):
+ return handle_torch_function(Tensor.unflatten, (self,), self, dim, sizes)
if not sizes:
raise RuntimeError("unflatten: sizes must be non-empty")
@@ -855,9 +823,8 @@
def rename_(self, *names, **rename_map):
"""In-place version of :meth:`~Tensor.rename`."""
- relevant_args = (self,)
- if type(self) is not Tensor and has_torch_function(relevant_args):
- return handle_torch_function(Tensor.rename_, relevant_args, self, *names, **rename_map)
+ if has_torch_function_unary(self):
+ return handle_torch_function(Tensor.rename_, (self,), self, *names, **rename_map)
# Note [rename_ / rename API]
# The Python API for these is different from the C++ API. In Python:
@@ -900,17 +867,15 @@
The named tensor API is experimental and subject to change.
"""
- relevant_args = (self,)
- if type(self) is not Tensor and has_torch_function(relevant_args):
- return handle_torch_function(Tensor.rename, relevant_args, self, *names, **rename_map)
+ if has_torch_function_unary(self):
+ return handle_torch_function(Tensor.rename, (self,), self, *names, **rename_map)
# See Note [rename_ / rename API]
return update_names(self, names, rename_map, inplace=False)
def _update_names(self, names, inplace):
- relevant_args = (self,)
- if type(self) is not Tensor and has_torch_function(relevant_args):
- return handle_torch_function(Tensor._update_names, relevant_args, self, names, inplace)
+ if has_torch_function_unary(self):
+ return handle_torch_function(Tensor._update_names, (self,), self, names, inplace)
# See Note [rename_ / rename API]
if inplace:
@@ -926,10 +891,9 @@
The attribute will then contain the gradients computed and future calls to
:func:`backward` will accumulate (add) gradients into it.
"""
- relevant_args = (self,)
- if type(self) is not Tensor and has_torch_function(relevant_args):
+ if has_torch_function_unary(self):
# TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
- return handle_torch_function(Tensor.grad.__get__, relevant_args, self) # type: ignore[attr-defined]
+ return handle_torch_function(Tensor.grad.__get__, (self,), self) # type: ignore[attr-defined]
if self.requires_grad and not hasattr(self, "retains_grad") and not self.is_leaf and self._grad is None:
warnings.warn("The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad "
@@ -941,18 +905,16 @@
@grad.setter
def grad(self, new_grad):
- relevant_args = (self,)
- if type(self) is not Tensor and has_torch_function(relevant_args):
+ if has_torch_function_unary(self):
# TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
- return handle_torch_function(Tensor.grad.__set__, relevant_args, self, new_grad) # type: ignore[attr-defined]
+ return handle_torch_function(Tensor.grad.__set__, (self,), self, new_grad) # type: ignore[attr-defined]
self._grad = new_grad
@grad.deleter
def grad(self):
- relevant_args = (self,)
- if type(self) is not Tensor and has_torch_function(relevant_args):
+ if has_torch_function_unary(self):
# TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
- return handle_torch_function(Tensor.grad.__delete__, relevant_args, self) # type: ignore[attr-defined]
+ return handle_torch_function(Tensor.grad.__delete__, (self,), self) # type: ignore[attr-defined]
del self._grad
@classmethod