move has_torch_function to C++, and make a special case object_has_torch_function (#48965)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48965

This PR pulls `__torch_function__` checking entirely into C++, and adds a special `object_has_torch_function` method for ops that take only a single argument, which lets us skip tuple construction and unpacking. We can now also do away with the Python-side fast bailout for `Tensor` (e.g. `if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors)`) because that bailout is actually slower than doing the whole check through the Python C API.
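
To make the call-site change concrete, here is a minimal sketch of the new dispatch pattern (the wrappers `my_unary_op` / `my_binary_op` and their bodies are hypothetical stand-ins; the real call sites are in `torch/functional.py` and `torch/nn/functional.py` in the diff below):

```python
import torch
from torch.overrides import (
    has_torch_function_unary, has_torch_function_variadic, handle_torch_function)

def my_unary_op(input):
    # Hypothetical wrapper illustrating the new single-argument pattern.
    if not torch.jit.is_scripting():
        # Old: `if type(input) is not Tensor and has_torch_function((input,)):`
        # New: one C++ call, no tuple construction, no Python-side type bailout.
        if has_torch_function_unary(input):
            return handle_torch_function(my_unary_op, (input,), input)
    return input.neg()  # stand-in for the real kernel

def my_binary_op(input, target):
    # Hypothetical wrapper illustrating the new variadic pattern.
    if not torch.jit.is_scripting():
        # Old: `if any(type(t) is not Tensor for t in (input, target)) and has_torch_function((input, target)):`
        if has_torch_function_variadic(input, target):
            return handle_torch_function(my_binary_op, (input, target), input, target)
    return input + target  # stand-in for the real kernel
```

The overloaded-argument tuple is still passed to `handle_torch_function` exactly as before; only the check itself moves into C++.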

Test Plan: Existing unit tests. Benchmarks are in #48966

Reviewed By: ezyang

Differential Revision: D25590732

Pulled By: robieta

fbshipit-source-id: 6bd74788f06cdd673f3a2db898143d18c577eb42
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 6427a4a..dd877da 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -4,9 +4,10 @@
 from torch import Tensor
 from enum import Enum
 from pathlib import Path
-from typing import (Any, BinaryIO, Callable, ContextManager, Dict, Iterator, List, NamedTuple,
-                    Optional, overload, Sequence, Tuple, TypeVar, Type, Union, Generic,
-                    Set, AnyStr)
+from typing import (
+    Any, BinaryIO, Callable, ContextManager, Dict, Iterable, Iterator, List,
+    NamedTuple, Optional, overload, Sequence, Tuple, TypeVar, Type, Union,
+    Generic, Set, AnyStr)
 from torch._six import inf
 
 from torch.types import _int, _float, _bool, _dtype, _device, _qscheme, _size, _layout, Device, Number, Storage
@@ -498,6 +499,9 @@
 def _set_qengine(qegine: _int) -> None: ...  # THPModule_setQEngine
 def _supported_qengines() -> List[_int]: ...  # THPModule_supportedQEngines
 def _is_xnnpack_enabled() -> _bool: ...  # THPModule_isEnabledXNNPACK
+def _has_torch_function(args: Iterable[Any]) -> _bool: ...  # THPModule_has_torch_function
+def _has_torch_function_unary(Any) -> _bool: ...  # THPModule_has_torch_function_unary
+def _has_torch_function_variadic(*args: Any) -> _bool: ...  # THPModule_has_torch_function_variadic
 def _vmapmode_increment_nesting() -> _int: ...  # THPModule_vmapmode_increment_nesting
 def _vmapmode_decrement_nesting() -> _int: ...  # THPModule_vmapmode_decrement_nesting
 def _log_api_usage_once(str) -> None: ...  # LogAPIUsageOnceFromPython
@@ -636,6 +640,8 @@
     _version: _int
     _base: Optional[Tensor]
     grad_fn: Any
+    _grad: Optional[Tensor]
+    _backward_hooks: Optional[Dict[_int, Callable[[Tensor], Optional[Tensor]]]]
     ${tensor_method_hints}
 
 # Defined in torch/csrc/multiprocessing/init.cpp
diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp
index f70bd1a..ca99965 100644
--- a/torch/csrc/Module.cpp
+++ b/torch/csrc/Module.cpp
@@ -40,6 +40,7 @@
 #include <torch/csrc/tensor/python_tensor.h>
 #include <torch/csrc/utils/disable_torch_function.h>
 #include <torch/csrc/utils/tensor_dtypes.h>
+#include <torch/csrc/utils/python_compat.h>
 #include <torch/csrc/utils/python_strings.h>
 #include <torch/csrc/utils/tensor_layouts.h>
 #include <torch/csrc/utils/tensor_memoryformats.h>
@@ -629,6 +630,9 @@
   {"_is_xnnpack_enabled", THPModule_isEnabledXNNPACK, METH_NOARGS, nullptr},
   {"_is_torch_function_enabled", THPModule_isEnabledTorchFunction, METH_NOARGS, nullptr},
   {"_disabled_torch_function_impl", THPModule_disable_torch_function, METH_VARARGS, nullptr},
+  {"_has_torch_function", THPModule_has_torch_function, METH_O, nullptr},
+  {"_has_torch_function_unary", THPModule_has_torch_function_unary, METH_O, nullptr},
+  {"_has_torch_function_variadic", MAYBE_WRAP_FASTCALL(THPModule_has_torch_function_variadic), MAYBE_METH_FASTCALL, nullptr},
   {nullptr, nullptr, 0, nullptr}
 };
 
diff --git a/torch/csrc/autograd/python_variable.h b/torch/csrc/autograd/python_variable.h
index cf05aec..41a2cca 100644
--- a/torch/csrc/autograd/python_variable.h
+++ b/torch/csrc/autograd/python_variable.h
@@ -24,18 +24,21 @@
 bool THPVariable_initModule(PyObject *module);
 THP_API PyObject * THPVariable_Wrap(torch::autograd::Variable var);
 
-static inline bool THPVariable_CheckExact(PyObject *obj) {
+static inline bool THPVariable_CheckTypeExact(PyTypeObject* tp) {
   // Check that a python object is a `Tensor`, but not a `Tensor` subclass.
   // (A subclass could have different semantics.) The one exception is
   // Parameter, which is used for Python bookkeeping but is equivalent to
   // Tensor as far as C++ is concerned.
-  auto obj_py_type = Py_TYPE(obj);
   return (
-    obj_py_type == (PyTypeObject*)THPVariableClass ||
-    obj_py_type == (PyTypeObject*)ParameterClass
+    tp == (PyTypeObject*)THPVariableClass ||
+    tp == (PyTypeObject*)ParameterClass
   );
 }
 
+static inline bool THPVariable_CheckExact(PyObject *obj) {
+  return THPVariable_CheckTypeExact(Py_TYPE(obj));
+}
+
 inline bool THPVariable_Check(PyObject *obj)
 {
   return THPVariableClass && PyObject_IsInstance(obj, THPVariableClass);
diff --git a/torch/csrc/utils/disable_torch_function.cpp b/torch/csrc/utils/disable_torch_function.cpp
index d3d1a3f..28414dc 100644
--- a/torch/csrc/utils/disable_torch_function.cpp
+++ b/torch/csrc/utils/disable_torch_function.cpp
@@ -1,6 +1,7 @@
 #include <torch/csrc/utils/disable_torch_function.h>
 #include <torch/csrc/utils/pybind.h>
 #include <torch/csrc/Exceptions.h>
+#include <torch/csrc/utils/python_strings.h>
 
 namespace torch {
   static thread_local bool enable_torch_function = true;
@@ -125,3 +126,111 @@
   return result;
   END_HANDLE_TH_ERRORS
 }
+
+// Makes sure that we don't check for __torch_function__ on basic Python types
+static bool is_basic_python_type(PyTypeObject *tp)
+{
+  return (
+    /* Basic number types */
+    tp == &PyBool_Type ||
+
+    tp == &PyLong_Type ||
+    tp == &PyFloat_Type ||
+    tp == &PyComplex_Type ||
+
+    /* Basic sequence types */
+    tp == &PyList_Type ||
+    tp == &PyTuple_Type ||
+    tp == &PyDict_Type ||
+    tp == &PySet_Type ||
+    tp == &PyFrozenSet_Type ||
+    tp == &PyUnicode_Type ||
+    tp == &PyBytes_Type ||
+
+    /* other builtins */
+    tp == &PySlice_Type ||
+    tp == Py_TYPE(Py_None) ||
+    tp == Py_TYPE(Py_Ellipsis) ||
+    tp == Py_TYPE(Py_NotImplemented) ||
+
+    PyModule_Check(tp) ||
+    /* sentinel to swallow trailing || */
+    false
+  );
+}
+
+inline bool has_torch_function_attr(PyObject* obj) {
+  auto attr = PyObject_FastGetAttrString(obj, "__torch_function__");
+  return (
+    attr.ptr() != nullptr &&
+    attr.ptr() != torch::disabled_torch_function);
+}
+
+namespace torch {
+auto check_has_torch_function(PyObject* obj) -> bool
+{
+  PyTypeObject *tp = Py_TYPE(obj);
+  return (
+    !THPVariable_CheckTypeExact(tp) &&
+    !is_basic_python_type(tp) &&
+    torch::torch_function_enabled() &&
+    has_torch_function_attr(obj)
+  );
+}
+} // namespace torch
+
+inline bool sequence_has_torch_function(PyObject* args) {
+  Py_ssize_t nargs = PySequence_Fast_GET_SIZE(args);
+  for (Py_ssize_t i = 0; i < nargs; i++) {
+    PyObject* obj = PySequence_Fast_GET_ITEM(args, i);
+    if (torch::check_has_torch_function(obj)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+inline bool array_has_torch_function(PyObject *const *args, Py_ssize_t nargs) {
+  for (Py_ssize_t i = 0; i < nargs; i++) {
+    if (torch::check_has_torch_function(args[i])) {
+      return true;
+    }
+  }
+  return false;
+}
+
+PyObject* THPModule_has_torch_function(PyObject*, PyObject *arg) {
+  bool result;  // NOLINT(cppcoreguidelines-init-variables)
+  if (PyTuple_CheckExact(arg) || PyList_CheckExact(arg)) {
+    // Fast path:
+    //   If we know that we have a tuple or list, we can skip an INCREF and
+    //   DECREF from PySequence_Fast. Core functions will always follow this
+    //   convention (almost always tuples), and it shaves ~3.5% off the cost of
+    //   the check.
+    result = sequence_has_torch_function(arg);
+  } else {
+    auto args = py::reinterpret_steal<py::object>(
+      PySequence_Fast(arg, "expected a sequence"));
+    result = sequence_has_torch_function(args.ptr());
+  }
+
+  if (result) {
+    Py_RETURN_TRUE;
+  }
+  Py_RETURN_FALSE;
+}
+
+PyObject* THPModule_has_torch_function_unary(PyObject*, PyObject *obj) {
+  // Special case `THPModule_has_torch_function` for the single arg case.
+  if (torch::check_has_torch_function(obj)) {
+    Py_RETURN_TRUE;
+  }
+  Py_RETURN_FALSE;
+}
+
+PyObject* THPModule_has_torch_function_variadic(PyObject*, PyObject *const *args, Py_ssize_t nargs) {
+  if (array_has_torch_function(args, nargs)) {
+    Py_RETURN_TRUE;
+  }
+  Py_RETURN_FALSE;
+}
diff --git a/torch/csrc/utils/disable_torch_function.h b/torch/csrc/utils/disable_torch_function.h
index 1216660..c4f9651 100644
--- a/torch/csrc/utils/disable_torch_function.h
+++ b/torch/csrc/utils/disable_torch_function.h
@@ -9,8 +9,12 @@
   bool torch_function_enabled();
   PyObject* disabled_torch_function_impl();
   void set_disabled_torch_function_impl(PyObject* value);
+  bool check_has_torch_function(PyObject* obj);
 }
 
 PyObject* THPModule_isEnabledTorchFunction(PyObject* self, PyObject *unused);
 PyObject* THPModule_DisableTorchFunctionType();
-PyObject* THPModule_disable_torch_function(PyObject *self, PyObject *args);
\ No newline at end of file
+PyObject* THPModule_disable_torch_function(PyObject *self, PyObject *args);
+PyObject* THPModule_has_torch_function(PyObject*, PyObject *arg);
+PyObject* THPModule_has_torch_function_unary(PyObject*, PyObject *obj);
+PyObject* THPModule_has_torch_function_variadic(PyObject*, PyObject *const *args, Py_ssize_t nargs);
diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp
index aa73b2a..4208f65 100644
--- a/torch/csrc/utils/python_arg_parser.cpp
+++ b/torch/csrc/utils/python_arg_parser.cpp
@@ -862,7 +862,7 @@
   }
 
   int i = 0;
-  if (self != nullptr && !THPVariable_CheckExact(self) && check_has_torch_function(self)) {
+  if (self != nullptr && check_has_torch_function(self)) {
     append_overloaded_arg(&this->overloaded_args, self);
   }
   for (auto& param : params) {
diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h
index 9fa4901..0f7f595 100644
--- a/torch/csrc/utils/python_arg_parser.h
+++ b/torch/csrc/utils/python_arg_parser.h
@@ -674,127 +674,6 @@
 }
 
 /*
- * Reference: https://github.com/numpy/numpy/blob/f4c497c768e0646df740b647782df463825bfd27/numpy/core/src/common/get_attr_string.h#L42
- *
- * Stripped down version of PyObject_GetAttrString,
- * avoids lookups for None, tuple, and List objects,
- * and doesn't create a PyErr since this code ignores it.
- *
- * This can be much faster then PyObject_GetAttrString where
- * exceptions are not used by caller.
- *
- * 'obj' is the object to search for attribute.
- *
- * 'name' is the attribute to search for.
- *
- * Returns a py::object wrapping the return value. If the attribute lookup failed
- * the value will be NULL.
- *
- */
-
-static py::object PyObject_FastGetAttrString(PyObject *obj, char *name)
-{
-    PyTypeObject *tp = Py_TYPE(obj);
-    PyObject *res = (PyObject *)NULL;
-
-    /* Attribute referenced by (char *)name */
-    if (tp->tp_getattr != NULL) {
-        res = (*tp->tp_getattr)(obj, name);
-        if (res == NULL) {
-          PyErr_Clear();
-        }
-    }
-    /* Attribute referenced by (PyObject *)name */
-    else if (tp->tp_getattro != NULL) {
-        PyObject *w = THPUtils_internString(name);
-        if (w == NULL) {
-          return py::object();
-        }
-        res = (*tp->tp_getattro)(obj, w);
-        Py_DECREF(w);
-        if (res == NULL) {
-            PyErr_Clear();
-        }
-    }
-    return py::reinterpret_steal<py::object>(res);
-}
-
-// Makes sure that we don't check for __torch_function__ on basic Python types
-static bool _is_basic_python_type(PyTypeObject *tp)
-{
-  return (
-    /* Basic number types */
-    tp == &PyBool_Type ||
-
-    tp == &PyLong_Type ||
-    tp == &PyFloat_Type ||
-    tp == &PyComplex_Type ||
-
-    /* Basic sequence types */
-    tp == &PyList_Type ||
-    tp == &PyTuple_Type ||
-    tp == &PyDict_Type ||
-    tp == &PySet_Type ||
-    tp == &PyFrozenSet_Type ||
-    tp == &PyUnicode_Type ||
-    tp == &PyBytes_Type ||
-
-    /* other builtins */
-    tp == &PySlice_Type ||
-    tp == Py_TYPE(Py_None) ||
-    tp == Py_TYPE(Py_Ellipsis) ||
-    tp == Py_TYPE(Py_NotImplemented) ||
-
-    PyModule_Check(tp) ||
-    /* sentinel to swallow trailing || */
-    false
-  );
-}
-
-/*
- * Lookup a special method, following the python approach of looking up
- * on the type object, rather than on the instance itself.
- *
- * Assumes that the special method is a torch-specific one, so does not
- * look at builtin types, nor does it look at a base Tensor.
- *
- * If no special method is found, return NULL, otherwise returns a new
- * reference to the function object
- *
- * In future, could be made more like _Py_LookupSpecial
- */
-
-static py::object PyTorch_LookupSpecial(PyObject *obj, char* name)
-{
-  if (THPVariable_CheckExact(obj)) {
-      return py::object();
-  }
-  PyTypeObject *tp = Py_TYPE(obj);
-  if (_is_basic_python_type(tp)) {
-    return py::object();
-  }
-  return PyObject_FastGetAttrString((PyObject *)tp, name);
-}
-
-/*
- * Checks if obj has a __torch_function__ implementation
- *
- * Returns true if an implementation is found and false otherwise
- *
- */
-static auto check_has_torch_function(PyObject* obj) -> bool
-{
-  if (!torch_function_enabled()) {
-    return false;
-  }
-  py::object method = PyTorch_LookupSpecial(obj, "__torch_function__");
-  if(method.ptr() != nullptr && method.ptr() != disabled_torch_function_impl()){
-    return true;
-  }
-  return false;
-}
-
-/*
  *
  * Handle __torch_function__ overrides if we know that there are overloaded
  * arguments.  All objects stored in r.overloaded_args must have a
diff --git a/torch/csrc/utils/python_compat.h b/torch/csrc/utils/python_compat.h
index 7e1cb0c..48ce9c1 100644
--- a/torch/csrc/utils/python_compat.h
+++ b/torch/csrc/utils/python_compat.h
@@ -2,6 +2,33 @@
 
 #include <torch/csrc/python_headers.h>
 
+#if PY_VERSION_HEX < 0x03070000
+// METH_FASTCALL was introduced in Python 3.7, so we wrap _PyCFunctionFast
+// signatures for earlier versions.
+
+template <PyObject* (*f)(PyObject*, PyObject *const *, Py_ssize_t)>
+PyObject* maybe_wrap_fastcall(PyObject *module, PyObject *args) {
+  return f(
+    module,
+
+    // _PyTuple_ITEMS
+    //   Because this is only a compat shim for Python 3.6, we don't have
+    //   to worry about the representation changing.
+    ((PyTupleObject *)args)->ob_item,
+    PySequence_Fast_GET_SIZE(args)
+  );
+}
+
+#define MAYBE_METH_FASTCALL METH_VARARGS
+#define MAYBE_WRAP_FASTCALL(f) maybe_wrap_fastcall<f>
+
+#else
+
+#define MAYBE_METH_FASTCALL METH_FASTCALL
+#define MAYBE_WRAP_FASTCALL(f) (PyCFunction)(void(*)(void))f
+
+#endif
+
 // PyPy 3.6 does not yet have PySlice_Unpack
 #if PY_VERSION_HEX < 0x03060100 || defined(PYPY_VERSION)
 
diff --git a/torch/csrc/utils/python_strings.h b/torch/csrc/utils/python_strings.h
index 55cf0c3..66e5bf1 100644
--- a/torch/csrc/utils/python_strings.h
+++ b/torch/csrc/utils/python_strings.h
@@ -4,6 +4,7 @@
 #include <stdexcept>
 #include <string>
 #include <torch/csrc/utils/object_ptr.h>
+#include <torch/csrc/utils/pybind.h>
 
 // Utilities for handling Python strings. Note that PyString, when defined, is
 // the same as PyBytes.
@@ -54,3 +55,49 @@
 inline void THPUtils_internStringInPlace(PyObject** obj) {
   PyUnicode_InternInPlace(obj);
 }
+
+/*
+ * Reference: https://github.com/numpy/numpy/blob/f4c497c768e0646df740b647782df463825bfd27/numpy/core/src/common/get_attr_string.h#L42
+ *
+ * Stripped down version of PyObject_GetAttrString,
+ * avoids lookups for None, tuple, and List objects,
+ * and doesn't create a PyErr since this code ignores it.
+ *
+ * This can be much faster than PyObject_GetAttrString where
+ * exceptions are not used by caller.
+ *
+ * 'obj' is the object to search for attribute.
+ *
+ * 'name' is the attribute to search for.
+ *
+ * Returns a py::object wrapping the return value. If the attribute lookup failed
+ * the value will be NULL.
+ *
+ */
+
+static py::object PyObject_FastGetAttrString(PyObject *obj, char *name)
+{
+    PyTypeObject *tp = Py_TYPE(obj);
+    PyObject *res = (PyObject *)nullptr;
+
+    /* Attribute referenced by (char *)name */
+    if (tp->tp_getattr != nullptr) {
+        res = (*tp->tp_getattr)(obj, name);
+        if (res == nullptr) {
+          PyErr_Clear();
+        }
+    }
+    /* Attribute referenced by (PyObject *)name */
+    else if (tp->tp_getattro != nullptr) {
+        auto w = py::reinterpret_steal<py::object>(
+          THPUtils_internString(name));
+        if (w.ptr() == nullptr) {
+          return py::object();
+        }
+        res = (*tp->tp_getattro)(obj, w.ptr());
+        if (res == nullptr) {
+            PyErr_Clear();
+        }
+    }
+    return py::reinterpret_steal<py::object>(res);
+}
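
In Python terms, the moved helper behaves roughly like a `getattr` that swallows the lookup failure instead of raising (a sketch; `fast_getattr` is a made-up name):

```python
def fast_getattr(obj, name):
    # Rough Python-level equivalent of PyObject_FastGetAttrString: return the
    # attribute if present, otherwise None, and never leave an exception set.
    try:
        return getattr(obj, name)
    except AttributeError:
        return None
```
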
diff --git a/torch/functional.py b/torch/functional.py
index 1442ab5..43fa0a3 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -6,7 +6,9 @@
 import torch.nn.functional as F
 from torch.types import _size
 from ._lowrank import svd_lowrank, pca_lowrank
-from .overrides import has_torch_function, handle_torch_function
+from .overrides import (
+    has_torch_function, has_torch_function_unary, has_torch_function_variadic,
+    handle_torch_function)
 from ._jit_internal import boolean_dispatch, List
 from ._jit_internal import _overload as overload
 from torch._autograd_functions import _LU
@@ -68,7 +70,7 @@
                 [0, 1, 2]])
     """
     if not torch.jit.is_scripting():
-        if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):
+        if has_torch_function(tensors):
             return handle_torch_function(broadcast_tensors, tensors, *tensors)
     return _VF.broadcast_tensors(tensors)  # type: ignore
 
@@ -146,7 +148,7 @@
                  [8, 9]]))
     """
     if not torch.jit.is_scripting():
-        if type(tensor) is not Tensor and has_torch_function((tensor,)):
+        if has_torch_function_unary(tensor):
             return handle_torch_function(split, (tensor,), tensor, split_size_or_sections,
                                          dim=dim)
     # Overwriting reason:
@@ -235,10 +237,9 @@
         tensor(2.9802e-08)
     """
     if not torch.jit.is_scripting():
-        tens_ops = (LU_data, LU_pivots)
-        if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+        if has_torch_function_variadic(LU_data, LU_pivots):
             return handle_torch_function(
-                lu_unpack, tens_ops, LU_data, LU_pivots, unpack_data=unpack_data,
+                lu_unpack, (LU_data, LU_pivots), LU_data, LU_pivots, unpack_data=unpack_data,
                 unpack_pivots=unpack_pivots)
     shape = LU_data.shape
     # In generalized LU factorization, the following shape relations hold:
@@ -398,7 +399,7 @@
                 [ 0.3311,  5.5201, -3.0356]])
     """
     if not torch.jit.is_scripting():
-        if any(type(t) is not Tensor for t in operands) and has_torch_function(operands):
+        if has_torch_function(operands):
             return handle_torch_function(einsum, operands, equation, *operands)
     if len(operands) == 1 and isinstance(operands[0], (list, tuple)):
         # the old interface of passing the operands as one list argument
@@ -448,7 +449,7 @@
 
 def _meshgrid(*tensors):
     if not torch.jit.is_scripting():
-        if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):
+        if has_torch_function(tensors):
             return handle_torch_function(meshgrid, tensors, *tensors)
     if len(tensors) == 1 and isinstance(tensors[0], (list, tuple)):
         # the old interface of passing the operands as one list argument
@@ -568,7 +569,7 @@
 
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(
                 stft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length,
                 window=window, center=center, pad_mode=pad_mode, normalized=normalized,
@@ -650,7 +651,7 @@
         Tensor: Least squares estimation of the original signal of size (..., signal_length)
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(
                 istft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length,
                 window=window, center=center, normalized=normalized, onesided=onesided,
@@ -734,7 +735,7 @@
 
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(
                 unique, (input,), input, sorted=sorted, return_inverse=return_inverse,
                 return_counts=return_counts, dim=dim)
@@ -810,7 +811,7 @@
         tensor([2, 2, 1, 2, 1])
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(
                 unique_consecutive, (input,), input, return_inverse=return_inverse,
                 return_counts=return_counts, dim=dim)
@@ -823,7 +824,7 @@
     # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
 
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return _unique_impl(input, sorted, return_inverse, return_counts, dim)
 
     output, _, counts = _unique_impl(input, sorted, return_inverse, return_counts, dim)
@@ -834,7 +835,7 @@
     # type: (Tensor, bool, bool, bool, Optional[int]) -> Tensor
 
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return _unique_impl(input, sorted, return_inverse, return_counts, dim)
 
     output, _, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim)
@@ -845,7 +846,7 @@
     # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
 
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return _unique_impl(input, sorted, return_inverse, return_counts, dim)
 
     output, inverse_indices, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim)
@@ -888,7 +889,7 @@
     # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
 
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
 
     output, _, counts = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
@@ -899,7 +900,7 @@
     # type: (Tensor, bool, bool, Optional[int]) -> Tensor
 
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
 
     output, _, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
@@ -910,7 +911,7 @@
     # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
 
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
 
     output, inverse_indices, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
@@ -1000,7 +1001,7 @@
                 [ -0.2850,   4.2573,  -3.5997]])
     """
     if not torch.jit.is_scripting():
-        if (type(a) is not Tensor or type(b) is not Tensor) and has_torch_function((a, b)):
+        if has_torch_function_variadic(a, b):
             return handle_torch_function(tensordot, (a, b), a, b, dims=dims)
     if isinstance(dims, (list, tuple)) or \
        (isinstance(dims, torch.Tensor) and dims.numel() > 1):
@@ -1046,7 +1047,7 @@
                 [3, 5]])
     """
     if not torch.jit.is_scripting():
-        if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):
+        if has_torch_function(tensors):
             return handle_torch_function(cartesian_prod, tensors, *tensors)
     return _VF.cartesian_prod(tensors)  # type: ignore
 
@@ -1080,7 +1081,7 @@
                 [0, 0, 0, 0, 0, 0, 0, 0, 0, 5],
                 [0, 0, 0, 0, 0, 0, 0, 0, 0, 6]])
     """
-    if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):
+    if has_torch_function(tensors):
         return handle_torch_function(block_diag, tensors, *tensors)
     return torch._C._VariableFunctions.block_diag(tensors)  # type: ignore
 
@@ -1128,7 +1129,7 @@
                 [2.2830, 0.3791]])
     """
     if not torch.jit.is_scripting():
-        if (type(x1) is not Tensor or type(x2) is not Tensor) and has_torch_function((x1, x2)):
+        if has_torch_function_variadic(x1, x2):
             return handle_torch_function(
                 cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode)
     if compute_mode == 'use_mm_for_euclid_dist_if_necessary':
@@ -1168,7 +1169,7 @@
         (tensor([0.5000]), tensor([1.]))
     """
     if not torch.jit.is_scripting():
-        if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):
+        if has_torch_function(tensors):
             return handle_torch_function(atleast_1d, tensors, *tensors)
     if len(tensors) == 1:
         tensors = tensors[0]
@@ -1203,7 +1204,7 @@
         (tensor([[0.5000]]), tensor([[1.]]))
     """
     if not torch.jit.is_scripting():
-        if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):
+        if has_torch_function(tensors):
             return handle_torch_function(atleast_2d, tensors, *tensors)
     if len(tensors) == 1:
         tensors = tensors[0]
@@ -1247,7 +1248,7 @@
         (tensor([[[0.5000]]]), tensor([[[1.]]]))
     """
     if not torch.jit.is_scripting():
-        if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):
+        if has_torch_function(tensors):
             return handle_torch_function(atleast_3d, tensors, *tensors)
     if len(tensors) == 1:
         tensors = tensors[0]
@@ -1380,7 +1381,7 @@
     """
 
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(
                 norm, (input,), input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype)
 
@@ -1476,7 +1477,7 @@
     .. _`[CLRS]`: https://mitpress.mit.edu/books/introduction-algorithms-third-edition
     """
     if not torch.jit.is_scripting():
-        if any(type(t) is not Tensor for t in matrices) and has_torch_function(matrices):
+        if has_torch_function(matrices):
             return handle_torch_function(chain_matmul, matrices, *matrices)
     return _VF.chain_matmul(matrices)  # type: ignore
 
@@ -1596,7 +1597,7 @@
 def _lu_with_infos(A, pivot=True, get_infos=False, out=None):
     # type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor, Tensor]]) -> Tuple[Tensor, Tensor, Tensor]
     if not torch.jit.is_scripting():
-        if type(A) is not Tensor and has_torch_function((A,)):
+        if has_torch_function_unary(A):
             return handle_torch_function(
                 lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out)
     result = _lu_impl(A, pivot, get_infos, out)
@@ -1612,7 +1613,7 @@
     # type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]
     # need to check for torch_function here so that we exit if
     if not torch.jit.is_scripting():
-        if type(A) is not Tensor and has_torch_function((A,)):
+        if has_torch_function_unary(A):
             return handle_torch_function(
                 lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out)
     result = _lu_impl(A, pivot, get_infos, out)
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index 162c22ea..2cfc1c2 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -9,12 +9,15 @@
 from torch._torch_docs import reproducibility_notes, tf32_notes
 
 from .._jit_internal import boolean_dispatch, _overload
-from ..overrides import has_torch_function, handle_torch_function
+from ..overrides import (
+    has_torch_function, has_torch_function_unary, has_torch_function_variadic,
+    handle_torch_function)
 from . import _reduction as _Reduction
 from . import grad  # noqa: F401
 from .modules import utils
 from .modules.utils import _single, _pair, _triple, _list_with_default
 
+
 Tensor = torch.Tensor
 
 conv1d = _add_docstr(
@@ -410,7 +413,7 @@
         http://arxiv.org/abs/1412.6071
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(
                 fractional_max_pool2d_with_indices,
                 (input,),
@@ -438,7 +441,7 @@
 ):
     # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], Optional[BroadcastingList2[float]], bool, Optional[Tensor]) -> Tensor  # noqa
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(
                 fractional_max_pool2d,
                 (input,),
@@ -500,7 +503,7 @@
         http://arxiv.org/abs/1412.6071
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(
                 fractional_max_pool3d_with_indices,
                 (input,),
@@ -532,7 +535,7 @@
 ):
     # type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], Optional[BroadcastingList3[float]], bool, Optional[Tensor]) -> Tensor  # noqa
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(
                 fractional_max_pool3d,
                 (input,),
@@ -569,7 +572,7 @@
     See :class:`~torch.nn.MaxPool1d` for details.
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(
                 max_pool1d_with_indices,
                 (input,),
@@ -589,7 +592,7 @@
 def _max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False):
     # type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], BroadcastingList1[int], bool, bool) -> Tensor  # noqa
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(
                 max_pool1d,
                 (input,),
@@ -627,7 +630,7 @@
     See :class:`~torch.nn.MaxPool2d` for details.
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(
                 max_pool2d_with_indices,
                 (input,),
@@ -647,7 +650,7 @@
 def _max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False):
     # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], BroadcastingList2[int], bool, bool) -> Tensor  # noqa
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(
                 max_pool2d,
                 (input,),
@@ -685,7 +688,7 @@
     See :class:`~torch.nn.MaxPool3d` for details.
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(
                 max_pool3d_with_indices,
                 (input,),
@@ -705,7 +708,7 @@
 def _max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False):
     # type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], BroadcastingList3[int], bool, bool) -> Tensor  # noqa
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(
                 max_pool3d,
                 (input,),
@@ -773,7 +776,7 @@
     See :class:`~torch.nn.MaxUnpool1d` for details.
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(
                 max_unpool1d,
                 (input,),
@@ -805,7 +808,7 @@
     See :class:`~torch.nn.MaxUnpool2d` for details.
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(
                 max_unpool2d,
                 (input,),
@@ -833,7 +836,7 @@
     See :class:`~torch.nn.MaxUnpool3d` for details.
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(
                 max_unpool3d,
                 (input,),
@@ -863,7 +866,7 @@
     See :class:`~torch.nn.LPPool2d` for details.
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(
                 lp_pool2d, (input,), input, norm_type, kernel_size, stride=stride, ceil_mode=ceil_mode
             )
@@ -885,7 +888,7 @@
     See :class:`~torch.nn.LPPool1d` for details.
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(
                 lp_pool1d, (input,), input, norm_type, kernel_size, stride=stride, ceil_mode=ceil_mode
             )
@@ -909,7 +912,7 @@
         return_indices: whether to return pooling indices. Default: ``False``
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(
                 adaptive_max_pool1d_with_indices, (input,), input, output_size, return_indices=return_indices
             )
@@ -919,7 +922,7 @@
 def _adaptive_max_pool1d(input, output_size, return_indices=False):
     # type: (Tensor, BroadcastingList1[int], bool) -> Tensor
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(
                 adaptive_max_pool1d, (input,), input, output_size, return_indices=return_indices
             )
@@ -950,7 +953,7 @@
         return_indices: whether to return pooling indices. Default: ``False``
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(
                 adaptive_max_pool2d_with_indices, (input,), input, output_size, return_indices=return_indices
             )
@@ -961,7 +964,7 @@
 def _adaptive_max_pool2d(input, output_size, return_indices=False):
     # type: (Tensor, BroadcastingList2[int], bool) -> Tensor
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(
                 adaptive_max_pool2d, (input,), input, output_size, return_indices=return_indices
             )
@@ -992,7 +995,7 @@
         return_indices: whether to return pooling indices. Default: ``False``
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(
                 adaptive_max_pool3d_with_indices, (input,), input, output_size, return_indices=return_indices
             )
@@ -1003,7 +1006,7 @@
 def _adaptive_max_pool3d(input, output_size, return_indices=False):
     # type: (Tensor, BroadcastingList3[int], bool) -> Tensor
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(
                 adaptive_max_pool3d, (input,), input, output_size, return_indices=return_indices
             )
@@ -1050,7 +1053,7 @@
             double-integer tuple)
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(adaptive_avg_pool2d, (input,), input, output_size)
     _output_size = _list_with_default(output_size, input.size())
     return torch._C._nn.adaptive_avg_pool2d(input, _output_size)
@@ -1069,7 +1072,7 @@
             triple-integer tuple)
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(adaptive_avg_pool3d, (input,), input, output_size)
     _output_size = _list_with_default(output_size, input.size())
     return torch._C._nn.adaptive_avg_pool3d(input, _output_size)
@@ -1090,7 +1093,7 @@
         inplace: If set to ``True``, will do this operation in-place. Default: ``False``
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(dropout, (input,), input, p=p, training=training, inplace=inplace)
     if p < 0.0 or p > 1.0:
         raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
@@ -1103,7 +1106,7 @@
     See :class:`~torch.nn.AlphaDropout` for details.
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(alpha_dropout, (input,), input, p=p, training=training, inplace=inplace)
     if p < 0.0 or p > 1.0:
         raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
@@ -1126,7 +1129,7 @@
         inplace: If set to ``True``, will do this operation in-place. Default: ``False``
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(dropout2d, (input,), input, p=p, training=training, inplace=inplace)
     if p < 0.0 or p > 1.0:
         raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
@@ -1151,7 +1154,7 @@
     # This is 100% the same code as dropout2d. We duplicate this code so that
     # stack traces are not confusing.
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(dropout3d, (input,), input, p=p, training=training, inplace=inplace)
     if p < 0.0 or p > 1.0:
         raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
@@ -1179,7 +1182,7 @@
         inplace: If set to ``True``, will do this operation in-place. Default: ``False``
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(
                 feature_alpha_dropout, (input,), input, p=p, training=training, inplace=inplace
             )
@@ -1194,7 +1197,7 @@
     See :class:`~torch.nn.Threshold` for more details.
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(_threshold, (input,), input, threshold, value, inplace=inplace)
     if inplace:
         result = _VF.threshold_(input, threshold, value)
@@ -1225,7 +1228,7 @@
     :class:`~torch.nn.ReLU` for more details.
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(relu, (input,), input, inplace=inplace)
     if inplace:
         result = torch.relu_(input)
@@ -1263,7 +1266,7 @@
         dim (int): dimension on which to split the input. Default: -1
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(glu, (input,), input, dim=dim)
     if input.dim() == 0:
         raise RuntimeError("glu does not support scalars because halving size must be even")
@@ -1278,7 +1281,7 @@
     details.
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(hardtanh, (input,), input, min_val=min_val, max_val=max_val, inplace=inplace)
     if inplace:
         result = torch._C._nn.hardtanh_(input, min_val, max_val)
@@ -1305,7 +1308,7 @@
     See :class:`~torch.nn.ReLU6` for more details.
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(relu6, (input,), input, inplace=inplace)
     return hardtanh(input, 0.0, 6.0, inplace)
 
@@ -1317,7 +1320,7 @@
     See :class:`~torch.nn.ELU` for more details.
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(elu, (input,), input, alpha=alpha, inplace=inplace)
     if inplace:
         result = torch._C._nn.elu_(input, alpha)
@@ -1347,7 +1350,7 @@
     See :class:`~torch.nn.SELU` for more details.
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(selu, (input,), input, inplace=inplace)
     if inplace:
         result = torch.selu_(input)
@@ -1375,7 +1378,7 @@
     See :class:`~torch.nn.CELU` for more details.
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(celu, (input,), input, alpha=alpha, inplace=inplace)
     if inplace:
         result = torch.celu_(input, alpha)
@@ -1404,7 +1407,7 @@
     See :class:`~torch.nn.LeakyReLU` for more details.
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(leaky_relu, (input,), input, negative_slope=negative_slope, inplace=inplace)
     if inplace:
         result = torch._C._nn.leaky_relu_(input, negative_slope)
@@ -1433,7 +1436,7 @@
     See :class:`~torch.nn.PReLU` for more details.
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(prelu, (input,), input, weight)
     return torch.prelu(input, weight)
 
@@ -1448,7 +1451,7 @@
     See :class:`~torch.nn.RReLU` for more details.
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(
                 rrelu, (input,), input, lower=lower, upper=upper, training=training, inplace=inplace
             )
@@ -1491,7 +1494,7 @@
     See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_.
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(gelu, (input,), input)
     return torch._C._nn.gelu(input)
 
@@ -1505,7 +1508,7 @@
     See :class:`~torch.nn.Hardshrink` for more details.
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(hardshrink, (input,), input, lambd=lambd)
     return torch.hardshrink(input, lambd)
 
@@ -1518,7 +1521,7 @@
     See :class:`~torch.nn.Tanhshrink` for more details.
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(tanhshrink, (input,), input)
     return input - input.tanh()
 
@@ -1531,7 +1534,7 @@
     See :class:`~torch.nn.Softsign` for more details.
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(softsign, (input,), input)
     return input / (input.abs() + 1)
 
@@ -1580,7 +1583,7 @@
           is performed. This is useful for preventing data type overflows. Default: None.
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(softmin, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype)
     if dim is None:
         dim = _get_softmax_dim("softmin", input.dim(), _stacklevel)
@@ -1617,7 +1620,7 @@
 
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype)
     if dim is None:
         dim = _get_softmax_dim("softmax", input.dim(), _stacklevel)
@@ -1669,7 +1672,7 @@
         https://arxiv.org/abs/1611.01144
     """
     if not torch.jit.is_scripting():
-        if type(logits) is not Tensor and has_torch_function((logits,)):
+        if has_torch_function_unary(logits):
             return handle_torch_function(gumbel_softmax, (logits,), logits, tau=tau, hard=hard, eps=eps, dim=dim)
     if eps != 1e-10:
         warnings.warn("`eps` parameter is deprecated and has no effect.")
@@ -1708,7 +1711,7 @@
           is performed. This is useful for preventing data type overflows. Default: None.
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(log_softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype)
     if dim is None:
         dim = _get_softmax_dim("log_softmax", input.dim(), _stacklevel)
@@ -1772,7 +1775,7 @@
     See :class:`~torch.nn.Hardsigmoid` for more details.
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(hardsigmoid, (input,), input, inplace=inplace)
     if inplace:
         return torch._C._nn.hardsigmoid_(input)
@@ -1793,10 +1796,9 @@
         - Bias: :math:`(out\_features)`
         - Output: :math:`(N, *, out\_features)`
     """
-    tens_ops = (input, weight)
     if not torch.jit.is_scripting():
-        if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
-            return handle_torch_function(linear, tens_ops, input, weight, bias=bias)
+        if has_torch_function_variadic(input, weight):
+            return handle_torch_function(linear, (input, weight), input, weight, bias=bias)
     if input.dim() == 2 and bias is not None:
         # fused op is marginally faster
         ret = torch.addmm(bias, input, weight.t())
@@ -1845,7 +1847,7 @@
     See :class:`~torch.nn.SiLU` for more details.
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(silu, (input,), input, inplace=inplace)
     if inplace:
         return torch._C._nn.silu_(input)
@@ -1870,7 +1872,7 @@
         https://arxiv.org/abs/1905.02244
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(hardswish, (input,), input, inplace=inplace)
     if inplace:
         return torch._C._nn.hardswish_(input)
@@ -2058,11 +2060,10 @@
     """
 
     if not torch.jit.is_scripting():
-        tens_ops = (input, weight)
-        if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+        if has_torch_function_variadic(input, weight):
             return handle_torch_function(
                 embedding_bag,
-                tens_ops,
+                (input, weight),
                 input,
                 weight,
                 offsets=offsets,
@@ -2188,7 +2189,7 @@
     :class:`~torch.nn.BatchNorm3d` for details.
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(
                 batch_norm,
                 (input,),
@@ -2227,7 +2228,7 @@
     :class:`~torch.nn.InstanceNorm3d` for details.
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(
                 instance_norm,
                 (input,),
@@ -2258,7 +2259,7 @@
     See :class:`~torch.nn.LayerNorm` for details.
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(
                 layer_norm, (input,), input, normalized_shape, weight=weight, bias=bias, eps=eps
             )
@@ -2273,7 +2274,7 @@
     See :class:`~torch.nn.GroupNorm` for details.
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(group_norm, (input,), input, num_groups, weight=weight, bias=bias, eps=eps)
     _verify_batch_size([input.size(0) * input.size(1) // num_groups, num_groups] + list(input.size()[2:]))
     return torch.group_norm(input, num_groups, weight, bias, eps, torch.backends.cudnn.enabled)
@@ -2287,7 +2288,7 @@
     See :class:`~torch.nn.LocalResponseNorm` for details.
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(local_response_norm, (input,), input, size, alpha=alpha, beta=beta, k=k)
     dim = input.dim()
     if dim < 3:
@@ -2425,11 +2426,10 @@
         >>> output.backward()
     """
     if not torch.jit.is_scripting():
-        tens_ops = (input, target)
-        if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+        if has_torch_function_variadic(input, target):
             return handle_torch_function(
                 nll_loss,
-                tens_ops,
+                (input, target),
                 input,
                 target,
                 weight=weight,
@@ -2522,11 +2522,10 @@
 
     """
     if not torch.jit.is_scripting():
-        tens_ops = (input, target)
-        if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+        if has_torch_function_variadic(input, target):
             return handle_torch_function(
                 poisson_nll_loss,
-                tens_ops,
+                (input, target),
                 input,
                 target,
                 log_input=log_input,
@@ -2593,11 +2592,10 @@
         In the next major release, ``'mean'`` will be changed to be the same as 'batchmean'.
     """
     if not torch.jit.is_scripting():
-        tens_ops = (input, target)
-        if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+        if has_torch_function_variadic(input, target):
             return handle_torch_function(
                 kl_div,
-                tens_ops,
+                (input, target),
                 input,
                 target,
                 size_average=size_average,
@@ -2679,11 +2677,10 @@
         >>> loss.backward()
     """
     if not torch.jit.is_scripting():
-        tens_ops = (input, target)
-        if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+        if has_torch_function_variadic(input, target):
             return handle_torch_function(
                 cross_entropy,
-                tens_ops,
+                (input, target),
                 input,
                 target,
                 weight=weight,
@@ -2739,11 +2736,10 @@
         >>> loss.backward()
     """
     if not torch.jit.is_scripting():
-        tens_ops = (input, target)
-        if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+        if has_torch_function_variadic(input, target):
             return handle_torch_function(
                 binary_cross_entropy,
-                tens_ops,
+                (input, target),
                 input,
                 target,
                 weight=weight,
@@ -2813,11 +2809,10 @@
          >>> loss.backward()
     """
     if not torch.jit.is_scripting():
-        tens_ops = (input, target)
-        if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+        if has_torch_function_variadic(input, target):
             return handle_torch_function(
                 binary_cross_entropy_with_logits,
-                tens_ops,
+                (input, target),
                 input,
                 target,
                 weight=weight,
@@ -2851,11 +2846,10 @@
     See :class:`~torch.nn.SmoothL1Loss` for details.
     """
     if not torch.jit.is_scripting():
-        tens_ops = (input, target)
-        if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+        if has_torch_function_variadic(input, target):
             return handle_torch_function(
                 smooth_l1_loss,
-                tens_ops,
+                (input, target),
                 input,
                 target,
                 size_average=size_average,
@@ -2891,10 +2885,9 @@
     See :class:`~torch.nn.L1Loss` for details.
     """
     if not torch.jit.is_scripting():
-        tens_ops = (input, target)
-        if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+        if has_torch_function_variadic(input, target):
             return handle_torch_function(
-                l1_loss, tens_ops, input, target, size_average=size_average, reduce=reduce, reduction=reduction
+                l1_loss, (input, target), input, target, size_average=size_average, reduce=reduce, reduction=reduction
             )
     if not (target.size() == input.size()):
         warnings.warn(
@@ -2924,10 +2917,9 @@
     See :class:`~torch.nn.MSELoss` for details.
     """
     if not torch.jit.is_scripting():
-        tens_ops = (input, target)
-        if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+        if has_torch_function_variadic(input, target):
             return handle_torch_function(
-                mse_loss, tens_ops, input, target, size_average=size_average, reduce=reduce, reduction=reduction
+                mse_loss, (input, target), input, target, size_average=size_average, reduce=reduce, reduction=reduction
             )
     if not (target.size() == input.size()):
         warnings.warn(
@@ -2957,11 +2949,10 @@
     See :class:`~torch.nn.MarginRankingLoss` for details.
     """  # noqa
     if not torch.jit.is_scripting():
-        tens_ops = (input1, input2, target)
-        if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+        if has_torch_function_variadic(input1, input2, target):
             return handle_torch_function(
                 margin_ranking_loss,
-                tens_ops,
+                (input1, input2, target),
                 input1,
                 input2,
                 target,
@@ -2997,11 +2988,10 @@
     See :class:`~torch.nn.HingeEmbeddingLoss` for details.
     """  # noqa
     if not torch.jit.is_scripting():
-        tens_ops = (input, target)
-        if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+        if has_torch_function_variadic(input, target):
             return handle_torch_function(
                 hinge_embedding_loss,
-                tens_ops,
+                (input, target),
                 input,
                 target,
                 margin=margin,
@@ -3028,11 +3018,10 @@
     See :class:`~torch.nn.MultiLabelMarginLoss` for details.
     """
     if not torch.jit.is_scripting():
-        tens_ops = (input, target)
-        if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+        if has_torch_function_variadic(input, target):
             return handle_torch_function(
                 multilabel_margin_loss,
-                tens_ops,
+                (input, target),
                 input,
                 target,
                 size_average=size_average,
@@ -3058,10 +3047,9 @@
     See :class:`~torch.nn.SoftMarginLoss` for details.
     """
     if not torch.jit.is_scripting():
-        tens_ops = (input, target)
-        if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+        if has_torch_function_variadic(input, target):
             return handle_torch_function(
-                soft_margin_loss, tens_ops, input, target, size_average=size_average, reduce=reduce, reduction=reduction
+                soft_margin_loss, (input, target), input, target, size_average=size_average, reduce=reduce, reduction=reduction
             )
     if size_average is not None or reduce is not None:
         reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
@@ -3083,11 +3071,10 @@
     See :class:`~torch.nn.MultiLabelSoftMarginLoss` for details.
     """
     if not torch.jit.is_scripting():
-        tens_ops = (input, target)
-        if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+        if has_torch_function_variadic(input, target):
             return handle_torch_function(
                 multilabel_soft_margin_loss,
-                tens_ops,
+                (input, target),
                 input,
                 target,
                 weight=weight,
@@ -3131,11 +3118,10 @@
     See :class:`~torch.nn.CosineEmbeddingLoss` for details.
     """  # noqa
     if not torch.jit.is_scripting():
-        tens_ops = (input1, input2, target)
-        if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+        if has_torch_function_variadic(input1, input2, target):
             return handle_torch_function(
                 cosine_embedding_loss,
-                tens_ops,
+                (input1, input2, target),
                 input1,
                 input2,
                 target,
@@ -3167,11 +3153,10 @@
     See :class:`~torch.nn.MultiMarginLoss` for details.
     """
     if not torch.jit.is_scripting():
-        tens_ops = (input, target)
-        if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+        if has_torch_function_variadic(input, target):
             return handle_torch_function(
                 multi_margin_loss,
-                tens_ops,
+                (input, target),
                 input,
                 target,
                 p=p,
@@ -3460,7 +3445,7 @@
         {backward_reproducibility_note}
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(
                 interpolate,
                 (input,),
@@ -3824,10 +3809,9 @@
     .. _`OpenCV`: https://github.com/opencv/opencv/blob/f345ed564a06178670750bad59526cfa4033be55/modules/imgproc/src/resize.cpp#L908
     """
     if not torch.jit.is_scripting():
-        tens_ops = (input, grid)
-        if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+        if has_torch_function_variadic(input, grid):
             return handle_torch_function(
-                grid_sample, tens_ops, input, grid, mode=mode, padding_mode=padding_mode, align_corners=align_corners
+                grid_sample, (input, grid), input, grid, mode=mode, padding_mode=padding_mode, align_corners=align_corners
             )
     if mode != "bilinear" and mode != "nearest" and mode != "bicubic":
         raise ValueError(
@@ -3916,7 +3900,7 @@
         (the center of the input image).
     """
     if not torch.jit.is_scripting():
-        if type(theta) is not Tensor and has_torch_function((theta,)):
+        if has_torch_function_unary(theta):
             return handle_torch_function(affine_grid, (theta,), theta, size, align_corners=align_corners)
     if align_corners is None:
         warnings.warn(
@@ -4025,7 +4009,7 @@
 
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(_pad, (input,), input, pad, mode=mode, value=value)
     assert len(pad) % 2 == 0, "Padding length must be divisible by 2"
     assert len(pad) // 2 <= input.dim(), "Padding length too large"
@@ -4208,11 +4192,10 @@
     See :class:`~torch.nn.TripletMarginLoss` for details
     """
     if not torch.jit.is_scripting():
-        tens_ops = (anchor, positive, negative)
-        if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+        if has_torch_function_variadic(anchor, positive, negative):
             return handle_torch_function(
                 triplet_margin_loss,
-                tens_ops,
+                (anchor, positive, negative),
                 anchor,
                 positive,
                 negative,
@@ -4250,11 +4233,10 @@
             "functions requiring Callables cannot be scripted."
         )
 
-    tens_ops = (anchor, positive, negative)
-    if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+    if has_torch_function_variadic(anchor, positive, negative):
         return handle_torch_function(
             triplet_margin_with_distance_loss,
-            tens_ops,
+            (anchor, positive, negative),
             anchor,
             positive,
             negative,
@@ -4304,7 +4286,7 @@
                                 operation won't be differentiable.
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(normalize, (input,), input, p=p, dim=dim, eps=eps, out=out)
     if out is None:
         denom = input.norm(p, dim, keepdim=True).clamp_min(eps).expand_as(input)
@@ -4337,7 +4319,7 @@
     See :class:`torch.nn.Unfold` for details
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(
                 unfold, (input,), input, kernel_size, dilation=dilation, padding=padding, stride=stride
             )
@@ -4365,7 +4347,7 @@
     See :class:`torch.nn.Fold` for details
     """
     if not torch.jit.is_scripting():
-        if type(input) is not Tensor and has_torch_function((input,)):
+        if has_torch_function_unary(input):
             return handle_torch_function(
                 fold, (input,), input, output_size, kernel_size, dilation=dilation, padding=padding, stride=stride
             )
@@ -4633,7 +4615,7 @@
     """
     if not torch.jit.is_scripting():
         tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
-        if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+        if has_torch_function(tens_ops):
             return handle_torch_function(
                 multi_head_attention_forward,
                 tens_ops,
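A minimal sketch of the two guard styles the diff converges on (the operator
names below are hypothetical and the snippet is not part of the patch): the
variadic form is used when only a few tensors need checking, while
`multi_head_attention_forward` keeps the iterable form because the tuple has
to be built for `handle_torch_function` anyway.

    # Sketch only: `simple_loss` and `many_arg_op` are hypothetical names.
    import torch
    from torch.overrides import (
        has_torch_function, has_torch_function_variadic, handle_torch_function)

    def simple_loss(input, target):
        # Few tensor arguments: the variadic form skips building a tuple
        # solely for the dispatch check.
        if has_torch_function_variadic(input, target):
            return handle_torch_function(simple_loss, (input, target), input, target)
        return (input - target).abs().mean()

    def many_arg_op(*tensors):
        # Many tensor arguments: the tuple already exists and is reused as
        # the relevant_args for handle_torch_function, so the iterable form
        # is kept.
        if has_torch_function(tensors):
            return handle_torch_function(many_arg_op, tensors, *tensors)
        return torch.stack(tensors).sum(dim=0)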
diff --git a/torch/overrides.py b/torch/overrides.py
index 889bf5b..96c7f82 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -28,7 +28,9 @@
 from typing import Dict, Set, List, Any, Callable, Iterable
 
 import torch
-from torch._C import _is_torch_function_enabled, _disabled_torch_function_impl
+from torch._C import (
+    _has_torch_function, _has_torch_function_unary,
+    _has_torch_function_variadic, _add_docstr)
 
 __all__ = [
     "get_ignored_functions",
@@ -167,6 +169,8 @@
         torch.nn.functional.upsample_bilinear,
         torch.nn.functional.upsample_nearest,
         torch.nn.functional.has_torch_function,
+        torch.nn.functional.has_torch_function_unary,
+        torch.nn.functional.has_torch_function_variadic,
         torch.nn.functional.handle_torch_function,
         torch.nn.functional.sigmoid,
         torch.nn.functional.hardsigmoid,
@@ -1193,33 +1197,52 @@
                     '__torch_function__: {}'
                     .format(func_name, [type(arg) for arg in overloaded_args]))
 
-def has_torch_function(relevant_args: Iterable[Any]) -> bool:
-    """Check for __torch_function__ implementations in the elements of an iterable.
-
-    Considers exact ``Tensor`` s non-dispatchable.
-
+has_torch_function = _add_docstr(
+    _has_torch_function,
+    r"""Check for __torch_function__ implementations in the elements of an iterable.
+    Considers exact ``Tensor`` s and ``Parameter`` s non-dispatchable.
     Arguments
     ---------
     relevant_args : iterable
         Iterable of arguments to check for __torch_function__ methods.
-
     Returns
     -------
     bool
         True if any of the elements of relevant_args have __torch_function__
         implementations, False otherwise.
-
     See Also
     ________
     torch.is_tensor_like
         Checks if something is a Tensor-like, including an exact ``Tensor``.
     """
-    return _is_torch_function_enabled() and any(
-        type(a) is not torch.Tensor and
-        getattr(a, '__torch_function__', _disabled_torch_function_impl)
-        is not _disabled_torch_function_impl
-        for a in relevant_args
-    )
+)
+
+has_torch_function_unary = _add_docstr(
+    _has_torch_function_unary,
+    r"""Special case of `has_torch_function` for single inputs.
+    Instead of:
+      `has_torch_function((t,))`
+    call:
+      `has_torch_function_unary(t)`
+    which skips unnecessary packing and unpacking work.
+    """
+)
+
+has_torch_function_variadic = _add_docstr(
+    _has_torch_function_variadic,
+    r"""Special case of `has_torch_function` that skips tuple creation.
+
+    This uses the METH_FASTCALL protocol introduced in Python 3.7; for 3.6
+    and before it has roughly equivalent performance compared to
+    `has_torch_function`.
+
+    Instead of:
+      `has_torch_function((a, b))`
+    call:
+      `has_torch_function_variadic(a, b)`
+    which skips unnecessary packing and unpacking work.
+    """
+)
 
 @functools.lru_cache(None)
 def get_overridable_functions() -> Dict[Any, List[Callable]]:
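As a quick, hedged illustration of the new predicates (the `DispatchingArray`
class below is hypothetical and not part of the patch): exact tensors are
treated as non-dispatchable, while any argument whose type defines
`__torch_function__` triggers dispatch.

    import torch
    from torch.overrides import (
        has_torch_function, has_torch_function_unary, has_torch_function_variadic)

    class DispatchingArray:
        # Hypothetical tensor-like type that opts into __torch_function__.
        def __torch_function__(self, func, types, args=(), kwargs=None):
            return func(*args, **(kwargs or {}))

    t = torch.ones(2)
    d = DispatchingArray()

    assert not has_torch_function((t,))       # exact Tensor: non-dispatchable
    assert not has_torch_function_unary(t)
    assert has_torch_function_unary(d)        # defines __torch_function__
    assert has_torch_function_variadic(t, d)  # one dispatchable arg suffices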
diff --git a/torch/tensor.py b/torch/tensor.py
index bd8f347..eedffc7 100644
--- a/torch/tensor.py
+++ b/torch/tensor.py
@@ -10,7 +10,9 @@
 from torch._namedtensor_internals import (
     update_names, check_serializing_named_tensor, resolve_ellipsis,
     unzip_namedshape, single_ellipsis_index, is_ellipsis)
-from torch.overrides import has_torch_function, handle_torch_function
+from torch.overrides import (
+    has_torch_function, has_torch_function_unary, has_torch_function_variadic,
+    handle_torch_function)
 import torch.utils.hooks as hooks
 
 
@@ -21,7 +23,7 @@
 
     @functools.wraps(f, assigned=assigned)
     def wrapped(*args, **kwargs):
-        if not all(type(t) is Tensor for t in args) and has_torch_function(args):
+        if has_torch_function(args):
             return handle_torch_function(wrapped, args, *args, **kwargs)
         try:
             return f(*args, **kwargs)
@@ -39,9 +41,8 @@
 # otherwise, it will not show up in autocomplete.
 class Tensor(torch._C._TensorBase):
     def __deepcopy__(self, memo):
-        relevant_args = (self,)
-        if type(self) is not Tensor and has_torch_function(relevant_args):
-            return handle_torch_function(Tensor.__deepcopy__, relevant_args, self, memo)
+        if has_torch_function_unary(self):
+            return handle_torch_function(Tensor.__deepcopy__, (self,), self, memo)
         if not self.is_leaf:
             raise RuntimeError("Only Tensors created explicitly by the user "
                                "(graph leaves) support the deepcopy protocol at the moment")
@@ -80,9 +81,8 @@
             return new_tensor
 
     def __reduce_ex__(self, proto):
-        relevant_args = (self,)
-        if type(self) is not Tensor and has_torch_function(relevant_args):
-            return handle_torch_function(Tensor.__reduce_ex__, relevant_args, self, proto)
+        if has_torch_function_unary(self):
+            return handle_torch_function(Tensor.__reduce_ex__, (self,), self, proto)
         check_serializing_named_tensor(self)
         # See Note [Don't serialize hooks]
         torch.utils.hooks.warn_if_has_hooks(self)
@@ -148,9 +148,8 @@
             return (torch._utils._rebuild_tensor_v2, args)
 
     def __setstate__(self, state):
-        relevant_args = (self,)
-        if type(self) is not Tensor and has_torch_function(relevant_args):
-            return handle_torch_function(Tensor.__setstate__, relevant_args, self, state)
+        if has_torch_function_unary(self):
+            return handle_torch_function(Tensor.__setstate__, (self,), self, state)
         # Warning: this method is NOT called when you torch.load() a tensor;
         # that is managed by _rebuild_tensor_v2
         if not self.is_leaf:
@@ -168,9 +167,8 @@
         self.requires_grad, _, self._backward_hooks = state
 
     def __repr__(self):
-        relevant_args = (self,)
-        if type(self) is not Tensor and has_torch_function(relevant_args):
-            return handle_torch_function(Tensor.__repr__, relevant_args, self)
+        if has_torch_function_unary(self):
+            return handle_torch_function(Tensor.__repr__, (self,), self)
         # All strings are unicode in Python 3.
         return torch._tensor_str._str(self)
 
@@ -215,11 +213,10 @@
                 used to compute the attr::tensors. All the provided inputs must be leaf
                 Tensors.
         """
-        relevant_args = (self,)
-        if type(self) is not Tensor and has_torch_function(relevant_args):
+        if has_torch_function_unary(self):
             return handle_torch_function(
                 Tensor.backward,
-                relevant_args,
+                (self,),
                 self,
                 gradient=gradient,
                 retain_graph=retain_graph,
@@ -256,9 +253,8 @@
 
             >>> h.remove()  # removes the hook
         """
-        relevant_args = (self,)
-        if type(self) is not Tensor and has_torch_function(relevant_args):
-            return handle_torch_function(Tensor.register_hook, relevant_args, self, hook)
+        if has_torch_function_unary(self):
+            return handle_torch_function(Tensor.register_hook, (self,), self, hook)
         if not self.requires_grad:
             raise RuntimeError("cannot register a hook on a tensor that "
                                "doesn't require gradient")
@@ -324,9 +320,8 @@
 
     def retain_grad(self):
         r"""Enables .grad attribute for non-leaf Tensors."""
-        relevant_args = (self,)
-        if type(self) is not Tensor and has_torch_function(relevant_args):
-            return handle_torch_function(Tensor.retain_grad, relevant_args, self)
+        if has_torch_function_unary(self):
+            return handle_torch_function(Tensor.retain_grad, (self,), self)
         if not self.requires_grad:
             raise RuntimeError("can't retain_grad on Tensor that has requires_grad=False")
         if self.is_leaf:  # no-op for leaves
@@ -355,9 +350,8 @@
 
         This is always ``True`` for CUDA tensors.
         """
-        relevant_args = (self,)
-        if type(self) is not Tensor and has_torch_function(relevant_args):
-            return handle_torch_function(Tensor.is_shared, relevant_args, self)
+        if has_torch_function_unary(self):
+            return handle_torch_function(Tensor.is_shared, (self,), self)
         return self.storage().is_shared()
 
     def share_memory_(self):
@@ -366,17 +360,15 @@
         This is a no-op if the underlying storage is already in shared memory
         and for CUDA tensors. Tensors in shared memory cannot be resized.
         """
-        relevant_args = (self,)
-        if type(self) is not Tensor and has_torch_function(relevant_args):
-            return handle_torch_function(Tensor.share_memory_, relevant_args, self)
+        if has_torch_function_unary(self):
+            return handle_torch_function(Tensor.share_memory_, (self,), self)
         self.storage().share_memory_()
         return self
 
     def __reversed__(self):
         r"""Reverses the tensor along dimension 0."""
-        relevant_args = (self,)
-        if type(self) is not Tensor and has_torch_function(relevant_args):
-            return handle_torch_function(Tensor.__reversed__, relevant_args, self)
+        if has_torch_function_unary(self):
+            return handle_torch_function(Tensor.__reversed__, (self,), self)
         if self.dim() == 0:
             return self
         else:
@@ -384,17 +376,15 @@
 
     def norm(self, p="fro", dim=None, keepdim=False, dtype=None):
         r"""See :func:`torch.norm`"""
-        relevant_args = (self,)
-        if type(self) is not Tensor and has_torch_function(relevant_args):
-            return handle_torch_function(Tensor.norm, relevant_args, self, p=p, dim=dim, keepdim=keepdim, dtype=dtype)
+        if has_torch_function_unary(self):
+            return handle_torch_function(Tensor.norm, (self,), self, p=p, dim=dim, keepdim=keepdim, dtype=dtype)
         return torch.norm(self, p, dim, keepdim, dtype=dtype)
 
     def lu(self, pivot=True, get_infos=False):
         r"""See :func:`torch.lu`"""
         # If get_infos is True, then we don't need to check for errors and vice versa
-        relevant_args = (self,)
-        if type(self) is not Tensor and has_torch_function(relevant_args):
-            return handle_torch_function(Tensor.lu, relevant_args, self, pivot=pivot, get_infos=get_infos)
+        if has_torch_function_unary(self):
+            return handle_torch_function(Tensor.lu, (self,), self, pivot=pivot, get_infos=get_infos)
 
         if not torch._jit_internal.is_scripting():
             if self.requires_grad:
@@ -434,10 +424,9 @@
           This function changed signature at version 0.4.1. Calling with
           the previous signature may cause error or return incorrect result.
         """
-        relevant_args = (self,)
-        if type(self) is not Tensor and has_torch_function(relevant_args):
+        if has_torch_function_unary(self):
             return handle_torch_function(
-                Tensor.stft, relevant_args, self, n_fft, hop_length=hop_length,
+                Tensor.stft, (self,), self, n_fft, hop_length=hop_length,
                 win_length=win_length, window=window, center=center, pad_mode=pad_mode, normalized=normalized,
                 onesided=onesided, return_complex=return_complex
             )
@@ -450,10 +439,9 @@
               onesided: Optional[bool] = None, length: Optional[int] = None,
               return_complex: bool = False):
         r"""See :func:`torch.istft`"""
-        relevant_args = (self,)
-        if type(self) is not Tensor and has_torch_function(relevant_args):
+        if has_torch_function_unary(self):
             return handle_torch_function(
-                Tensor.istft, relevant_args, self, n_fft, hop_length=hop_length, win_length=win_length,
+                Tensor.istft, (self,), self, n_fft, hop_length=hop_length, win_length=win_length,
                 window=window, center=center, normalized=normalized, onesided=onesided, length=length,
                 return_complex=return_complex
             )
@@ -461,17 +449,15 @@
                            normalized, onesided, length, return_complex=return_complex)
 
     def resize(self, *sizes):
-        relevant_args = (self,)
-        if type(self) is not Tensor and has_torch_function(relevant_args):
-            return handle_torch_function(Tensor.resize, relevant_args, self, *sizes)
+        if has_torch_function_unary(self):
+            return handle_torch_function(Tensor.resize, (self,), self, *sizes)
         warnings.warn("non-inplace resize is deprecated")
         from torch.autograd._functions import Resize
         return Resize.apply(self, sizes)
 
     def resize_as(self, tensor):
-        relevant_args = (self, tensor)
-        if type(self) is not Tensor and type(tensor) is not Tensor and has_torch_function(relevant_args):
-            return handle_torch_function(Tensor.resize_as, relevant_args, self, tensor)
+        if has_torch_function_variadic(self, tensor):
+            return handle_torch_function(Tensor.resize_as, (self, tensor), self, tensor)
         warnings.warn("non-inplace resize_as is deprecated")
         from torch.autograd._functions import Resize
         return Resize.apply(self, tensor.size())
@@ -479,9 +465,8 @@
     def split(self, split_size, dim=0):
         r"""See :func:`torch.split`
         """
-        relevant_args = (self,)
-        if type(self) is not Tensor and has_torch_function(relevant_args):
-            return handle_torch_function(Tensor.split, relevant_args, self, split_size, dim=dim)
+        if has_torch_function_unary(self):
+            return handle_torch_function(Tensor.split, (self,), self, split_size, dim=dim)
         if isinstance(split_size, int):
             return super(Tensor, self).split(split_size, dim)
         elif isinstance(split_size, Tensor):
@@ -498,10 +483,9 @@
 
         See :func:`torch.unique`
         """
-        relevant_args = (self,)
-        if type(self) is not Tensor and has_torch_function(relevant_args):
+        if has_torch_function_unary(self):
             return handle_torch_function(
-                Tensor.unique, relevant_args, self, sorted=sorted, return_inverse=return_inverse,
+                Tensor.unique, (self,), self, sorted=sorted, return_inverse=return_inverse,
                 return_counts=return_counts, dim=dim
             )
         return torch.unique(self, sorted=sorted, return_inverse=return_inverse, return_counts=return_counts, dim=dim)
@@ -511,24 +495,21 @@
 
         See :func:`torch.unique_consecutive`
         """
-        relevant_args = (self,)
-        if type(self) is not Tensor and has_torch_function(relevant_args):
+        if has_torch_function_unary(self):
             return handle_torch_function(
-                Tensor.unique_consecutive, relevant_args, self, return_inverse=return_inverse,
+                Tensor.unique_consecutive, (self,), self, return_inverse=return_inverse,
                 return_counts=return_counts, dim=dim
             )
         return torch.unique_consecutive(self, return_inverse=return_inverse, return_counts=return_counts, dim=dim)
 
     def __rsub__(self, other):
-        relevant_args = (self, other)
-        if type(self) is not Tensor and type(other) is not Tensor and has_torch_function(relevant_args):
-            return handle_torch_function(Tensor.__rsub__, relevant_args, self, other)
+        if has_torch_function_variadic(self, other):
+            return handle_torch_function(Tensor.__rsub__, (self, other), self, other)
         return _C._VariableFunctions.rsub(self, other)
 
     def __rdiv__(self, other):
-        relevant_args = (self, other)
-        if type(self) is not Tensor and type(other) is not Tensor and has_torch_function(relevant_args):
-            return handle_torch_function(Tensor.__rdiv__, relevant_args, self, other)
+        if has_torch_function_variadic(self, other):
+            return handle_torch_function(Tensor.__rdiv__, (self, other), self, other)
         return self.reciprocal() * other
 
     __rtruediv__ = __rdiv__
@@ -537,17 +518,15 @@
     __pow__ = _C._TensorBase.pow
 
     def __format__(self, format_spec):
-        relevant_args = (self,)
-        if type(self) is not Tensor and has_torch_function(relevant_args):
-            return handle_torch_function(Tensor.__format__, relevant_args, self, format_spec)
+        if has_torch_function_unary(self):
+            return handle_torch_function(Tensor.__format__, (self,), self, format_spec)
         if self.dim() == 0:
             return self.item().__format__(format_spec)
         return object.__format__(self, format_spec)
 
     def __ipow__(self, other):  # type: ignore[misc]
-        relevant_args = (self, other)
-        if type(self) is not Tensor and type(other) is not Tensor and has_torch_function(relevant_args):
-            return handle_torch_function(Tensor.__ipow__, relevant_args, self, other)
+        if has_torch_function_variadic(self, other):
+            return handle_torch_function(Tensor.__ipow__, (self, other), self, other)
         return NotImplemented
 
     @_wrap_type_error_to_not_implemented
@@ -567,9 +546,8 @@
     __abs__ = _C._TensorBase.abs
 
     def __len__(self):
-        relevant_args = (self,)
-        if type(self) is not Tensor and has_torch_function(relevant_args):
-            return handle_torch_function(Tensor.__len__, relevant_args, self)
+        if has_torch_function_unary(self):
+            return handle_torch_function(Tensor.__len__, (self,), self)
         if self.dim() == 0:
             raise TypeError("len() of a 0-d tensor")
         return self.shape[0]
@@ -581,9 +559,8 @@
         # (e.g., if you zip(*hiddens), the eager map will force all the
         # indexes of hiddens[0] before hiddens[1], while the generator
         # map will interleave them.)
-        relevant_args = (self,)
-        if type(self) is not Tensor and has_torch_function(relevant_args):
-            return handle_torch_function(Tensor.__iter__, relevant_args, self)
+        if has_torch_function_unary(self):
+            return handle_torch_function(Tensor.__iter__, (self,), self)
         if self.dim() == 0:
             raise TypeError('iteration over a 0-d tensor')
         if torch._C._get_tracing_state():
@@ -594,15 +571,13 @@
         return iter(self.unbind(0))
 
     def __hash__(self):
-        relevant_args = (self,)
-        if type(self) is not Tensor and has_torch_function(relevant_args):
-            return handle_torch_function(Tensor.__hash__, relevant_args, self)
+        if has_torch_function_unary(self):
+            return handle_torch_function(Tensor.__hash__, (self,), self)
         return id(self)
 
     def __dir__(self):
-        relevant_args = (self,)
-        if type(self) is not Tensor and has_torch_function(relevant_args):
-            return handle_torch_function(Tensor.__dir__, relevant_args, self)
+        if has_torch_function_unary(self):
+            return handle_torch_function(Tensor.__dir__, (self,), self)
         if self.is_quantized:
             warnings.warn('Only a small subset of methods are supported for quantized tensors.')
         tensor_methods = dir(self.__class__)
@@ -620,9 +595,8 @@
     __array_priority__ = 1000    # prefer Tensor ops over numpy ones
 
     def __array__(self, dtype=None):
-        relevant_args = (self,)
-        if type(self) is not Tensor and has_torch_function(relevant_args):
-            return handle_torch_function(Tensor.__array__, relevant_args, self, dtype=dtype)
+        if has_torch_function_unary(self):
+            return handle_torch_function(Tensor.__array__, (self,), self, dtype=dtype)
         if dtype is None:
             return self.numpy()
         else:
@@ -631,9 +605,8 @@
     # Wrap Numpy array again in a suitable tensor when done, to support e.g.
     # `numpy.sin(tensor) -> tensor` or `numpy.greater(tensor, 0) -> ByteTensor`
     def __array_wrap__(self, array):
-        relevant_args = (self,)
-        if type(self) is not Tensor and has_torch_function(relevant_args):
-            return handle_torch_function(Tensor.__array_wrap__, relevant_args, self, array=array)
+        if has_torch_function_unary(self):
+            return handle_torch_function(Tensor.__array_wrap__, (self,), self, array=array)
         if array.dtype == bool:
             # Workaround, torch has no built-in bool tensor
             array = array.astype('uint8')
@@ -646,9 +619,8 @@
             element (Tensor or scalar): element to be checked
                 for presence in current tensor
         """
-        relevant_args = (self,)
-        if type(self) is not Tensor and has_torch_function(relevant_args):
-            return handle_torch_function(Tensor.__contains__, relevant_args, self, element)
+        if has_torch_function_unary(self):
+            return handle_torch_function(Tensor.__contains__, (self,), self, element)
         if isinstance(element, (torch.Tensor, Number)):
             # type hint doesn't understand the __contains__ result array
             return (element == self).any().item()  # type: ignore[union-attr]
@@ -665,10 +637,9 @@
         See:
         https://numba.pydata.org/numba-doc/latest/cuda/cuda_array_interface.html
         """
-        relevant_args = (self,)
-        if type(self) is not Tensor and has_torch_function(relevant_args):
+        if has_torch_function_unary(self):
             # TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
-            return handle_torch_function(Tensor.__cuda_array_interface__.__get__, relevant_args, self)  # type: ignore[attr-defined]
+            return handle_torch_function(Tensor.__cuda_array_interface__.__get__, (self,), self)  # type: ignore[attr-defined]
 
         # raise AttributeError for unsupported tensors, so that
         # hasattr(cpu_tensor, "__cuda_array_interface__") is False.
@@ -761,9 +732,8 @@
             The named tensor API is experimental and subject to change.
 
         """
-        relevant_args = (self,)
-        if type(self) is not Tensor and has_torch_function(relevant_args):
-            return handle_torch_function(Tensor.refine_names, relevant_args, self, *names)
+        if has_torch_function_unary(self):
+            return handle_torch_function(Tensor.refine_names, (self,), self, *names)
         names = resolve_ellipsis(names, self.names, 'refine_names')
         return super(Tensor, self).refine_names(names)
 
@@ -803,9 +773,8 @@
             The named tensor API is experimental and subject to change.
 
         """
-        relevant_args = (self,)
-        if type(self) is not Tensor and has_torch_function(relevant_args):
-            return handle_torch_function(Tensor.align_to, relevant_args, self, *names)
+        if has_torch_function_unary(self):
+            return handle_torch_function(Tensor.align_to, (self,), self, *names)
         ellipsis_idx = single_ellipsis_index(names, 'align_to')
         if ellipsis_idx is None:
             return super(Tensor, self).align_to(names)
@@ -839,9 +808,8 @@
         .. warning::
             The named tensor API is experimental and subject to change.
         """
-        relevant_args = (self,)
-        if type(self) is not Tensor and has_torch_function(relevant_args):
-            return handle_torch_function(Tensor.unflatten, relevant_args, self, dim, sizes)
+        if has_torch_function_unary(self):
+            return handle_torch_function(Tensor.unflatten, (self,), self, dim, sizes)
 
         if not sizes:
             raise RuntimeError("unflatten: sizes must be non-empty")
@@ -855,9 +823,8 @@
     def rename_(self, *names, **rename_map):
         """In-place version of :meth:`~Tensor.rename`."""
 
-        relevant_args = (self,)
-        if type(self) is not Tensor and has_torch_function(relevant_args):
-            return handle_torch_function(Tensor.rename_, relevant_args, self, *names, **rename_map)
+        if has_torch_function_unary(self):
+            return handle_torch_function(Tensor.rename_, (self,), self, *names, **rename_map)
 
         # Note [rename_ / rename API]
         # The Python API for these is different from the C++ API. In Python:
@@ -900,17 +867,15 @@
             The named tensor API is experimental and subject to change.
 
         """
-        relevant_args = (self,)
-        if type(self) is not Tensor and has_torch_function(relevant_args):
-            return handle_torch_function(Tensor.rename, relevant_args, self, *names, **rename_map)
+        if has_torch_function_unary(self):
+            return handle_torch_function(Tensor.rename, (self,), self, *names, **rename_map)
 
         # See Note [rename_ / rename API]
         return update_names(self, names, rename_map, inplace=False)
 
     def _update_names(self, names, inplace):
-        relevant_args = (self,)
-        if type(self) is not Tensor and has_torch_function(relevant_args):
-            return handle_torch_function(Tensor._update_names, relevant_args, self, names, inplace)
+        if has_torch_function_unary(self):
+            return handle_torch_function(Tensor._update_names, (self,), self, names, inplace)
 
         # See Note [rename_ / rename API]
         if inplace:
@@ -926,10 +891,9 @@
         The attribute will then contain the gradients computed and future calls to
         :func:`backward` will accumulate (add) gradients into it.
         """
-        relevant_args = (self,)
-        if type(self) is not Tensor and has_torch_function(relevant_args):
+        if has_torch_function_unary(self):
             # TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
-            return handle_torch_function(Tensor.grad.__get__, relevant_args, self)  # type: ignore[attr-defined]
+            return handle_torch_function(Tensor.grad.__get__, (self,), self)  # type: ignore[attr-defined]
 
         if self.requires_grad and not hasattr(self, "retains_grad") and not self.is_leaf and self._grad is None:
             warnings.warn("The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad "
@@ -941,18 +905,16 @@
 
     @grad.setter
     def grad(self, new_grad):
-        relevant_args = (self,)
-        if type(self) is not Tensor and has_torch_function(relevant_args):
+        if has_torch_function_unary(self):
             # TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
-            return handle_torch_function(Tensor.grad.__set__, relevant_args, self, new_grad)  # type: ignore[attr-defined]
+            return handle_torch_function(Tensor.grad.__set__, (self,), self, new_grad)  # type: ignore[attr-defined]
         self._grad = new_grad
 
     @grad.deleter
     def grad(self):
-        relevant_args = (self,)
-        if type(self) is not Tensor and has_torch_function(relevant_args):
+        if has_torch_function_unary(self):
             # TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
-            return handle_torch_function(Tensor.grad.__delete__, relevant_args, self)  # type: ignore[attr-defined]
+            return handle_torch_function(Tensor.grad.__delete__, (self,), self)  # type: ignore[attr-defined]
         del self._grad
 
     @classmethod