rename DisableTorchFunction to DisableTorchFunctionSubclass (#88218)

First half of #87990. This doesn't change any of the behavior and is just a rename

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88218
Approved by: https://github.com/ezyang, https://github.com/zou3519
diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json
index ba4a2e9..8a66dc1 100644
--- a/test/allowlist_for_publicAPI.json
+++ b/test/allowlist_for_publicAPI.json
@@ -1128,7 +1128,7 @@
     "BFloat16Tensor",
     "ComplexDoubleStorage",
     "ComplexFloatStorage",
-    "DisableTorchFunction",
+    "DisableTorchFunctionSubclass",
     "Generator",
     "HalfStorage",
     "HalfTensor",
diff --git a/test/profiler/test_profiler_tree.py b/test/profiler/test_profiler_tree.py
index d4a31c6..2105302 100644
--- a/test/profiler/test_profiler_tree.py
+++ b/test/profiler/test_profiler_tree.py
@@ -26,7 +26,7 @@
     "torch/profiler/profiler.py(...): start": KEEP_ELLIPSES,
     "torch/profiler/profiler.py(...): stop_trace": KEEP_ELLIPSES,
     "torch/profiler/profiler.py(...): _transit_action": KEEP_ELLIPSES,
-    "<built-in method __exit__ of torch._C.DisableTorchFunction object at 0xXXXXXXXXXXXX>": PRUNE_ALL,
+    "<built-in method __exit__ of torch._C.DisableTorchFunctionSubclass object at 0xXXXXXXXXXXXX>": PRUNE_ALL,
     "cudaStreamIsCapturing": PRUNE_ALL,
     "cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags": PRUNE_ALL,
 }
diff --git a/test/test_overrides.py b/test/test_overrides.py
index 7082f75..01c763a 100644
--- a/test/test_overrides.py
+++ b/test/test_overrides.py
@@ -1448,7 +1448,7 @@
 
         x = B(torch.randn(5))
         with A():
-            with torch._C.DisableTorchFunction():
+            with torch._C.DisableTorchFunctionSubclass():
                 self.assertNotIsInstance(torch.sum(x), B)
 
         self.assertTrue(called)
@@ -1460,7 +1460,7 @@
             pass
 
         x = A(torch.randn(5))
-        with torch._C.DisableTorchFunction():
+        with torch._C.DisableTorchFunctionSubclass():
             g = torch._C._EnableTorchFunction()
             try:
                 self.assertIsInstance(torch.sum(x), A)
diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py
index 4d2df65..6897c31 100644
--- a/test/test_public_bindings.py
+++ b/test/test_public_bindings.py
@@ -99,7 +99,7 @@
             "device",
             "DeviceObjType",
             "DictType",
-            "DisableTorchFunction",
+            "DisableTorchFunctionSubclass",
             "DispatchKey",
             "DispatchKeySet",
             "dtype",
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 2d20da2..79dd638 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -108,7 +108,7 @@
     ...
 
 # Defined in torch/csrc/utils/disable_torch_function.cpp
-def DisableTorchFunction(): ...
+def DisableTorchFunctionSubclass(): ...
 
 # Defined in torch/csrc/utils/tensor_layouts.cpp
 strided : layout = ...
diff --git a/torch/__init__.py b/torch/__init__.py
index ae55f59..ef6138c 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -288,7 +288,7 @@
         if (isinstance(obj, Callable) or inspect.isclass(obj)):  # type: ignore[arg-type]
             if (obj.__module__ != 'torch'):
                 # TODO: fix their module from C++ side
-                if name not in ['DisableTorchFunction', 'Generator']:
+                if name not in ['DisableTorchFunctionSubclass', 'Generator']:
                     obj.__module__ = 'torch'
 
 if not TYPE_CHECKING:
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index d3c5140..9d87897 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -506,7 +506,7 @@
                 )
             # Disable __torch_function__ to prevent cloning of `value` to hit
             # us
-            with torch._C.DisableTorchFunction():
+            with torch._C.DisableTorchFunctionSubclass():
                 if is_constant_source(self.get_source()):
                     return self.tx.output.register_attr_or_module(
                         value,
diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py
index da32712..6e4325b 100644
--- a/torch/_dynamo/variables/misc.py
+++ b/torch/_dynamo/variables/misc.py
@@ -538,7 +538,7 @@
             options = VariableTracker.propagate(self, new_args, new_kwargs.values())
             # Disable __torch_function__ here to prevent the clone of the
             # example tensor from going into the override.
-            with torch._C.DisableTorchFunction():
+            with torch._C.DisableTorchFunctionSubclass():
                 if isinstance(args[0], TorchVariable):
                     return TensorVariable.create(
                         tx=tx,
diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py
index 315c2b1..5a30f83 100644
--- a/torch/_dynamo/variables/tensor.py
+++ b/torch/_dynamo/variables/tensor.py
@@ -704,7 +704,7 @@
 
         # Disable __torch_function__ here to prevent the clone of the
         # example tensor from going into the override.
-        with torch._C.DisableTorchFunction():
+        with torch._C.DisableTorchFunctionSubclass():
             return tx.inline_user_function_return(tf_func_var, tf_args, {})
 
 
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index 14f5cd2..79af51e 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -1093,5 +1093,5 @@
             memo[id(tensor)] = out
             return out
         else:
-            with torch._C.DisableTorchFunction():
+            with torch._C.DisableTorchFunctionSubclass():
                 return func(*args, **kwargs)
diff --git a/torch/_tensor.py b/torch/_tensor.py
index 793034b..41b6569 100644
--- a/torch/_tensor.py
+++ b/torch/_tensor.py
@@ -1297,7 +1297,7 @@
         if not all(issubclass(cls, t) for t in types):
             return NotImplemented
 
-        with _C.DisableTorchFunction():
+        with _C.DisableTorchFunctionSubclass():
             ret = func(*args, **kwargs)
             if func in get_default_nowrap_functions():
                 return ret
diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp
index b8693a4..efe6c18 100644
--- a/torch/csrc/Module.cpp
+++ b/torch/csrc/Module.cpp
@@ -1594,8 +1594,8 @@
       (PyObject*)THPDefaultCPUGenerator,
       /* incref= */ false));
   ASSERT_TRUE(set_module_attr(
-      "DisableTorchFunction",
-      (PyObject*)THPModule_DisableTorchFunctionType(),
+      "DisableTorchFunctionSubclass",
+      (PyObject*)THPModule_DisableTorchFunctionSubclassType(),
       /* incref= */ false));
   torch::set_disabled_torch_function_impl(
       PyObject_GetAttrString(module, "_disabled_torch_function_impl"));
diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp
index ee96323..d26db95 100644
--- a/torch/csrc/autograd/init.cpp
+++ b/torch/csrc/autograd/init.cpp
@@ -343,7 +343,6 @@
       _C_m, "_RestorePythonTLSSnapshot")
       .def(py::init<>());
 
-  // TODO: line up this binding with DisableTorchFunction
   py::class_<torch::DisableTorchDispatch>(_C_m, "_DisableTorchDispatch")
       .def(py::init<>());
   py::class_<EnableTorchFunction>(_C_m, "_EnableTorchFunction")
diff --git a/torch/csrc/utils/disable_torch_function.cpp b/torch/csrc/utils/disable_torch_function.cpp
index 682120d..516e6b8 100644
--- a/torch/csrc/utils/disable_torch_function.cpp
+++ b/torch/csrc/utils/disable_torch_function.cpp
@@ -35,18 +35,20 @@
   PyObject_HEAD
       /* Type-specific fields go here. */
       bool old_state;
-} DisableTorchFunction;
+} DisableTorchFunctionSubclass;
 
-PyObject* DisableTorchFunction__enter(PyObject* self, PyObject* unused) {
-  ((DisableTorchFunction*)self)->old_state =
+PyObject* DisableTorchFunctionSubclass__enter(
+    PyObject* self,
+    PyObject* unused) {
+  ((DisableTorchFunctionSubclass*)self)->old_state =
       at::impl::PythonTorchFunctionTLS::is_disabled();
   at::impl::PythonTorchFunctionTLS::set_disabled(true);
   Py_RETURN_NONE;
 }
 
-PyObject* DisableTorchFunction__exit(PyObject* self, PyObject* unused) {
+PyObject* DisableTorchFunctionSubclass__exit(PyObject* self, PyObject* unused) {
   at::impl::PythonTorchFunctionTLS::set_disabled(
-      ((DisableTorchFunction*)self)->old_state);
+      ((DisableTorchFunctionSubclass*)self)->old_state);
   Py_RETURN_NONE;
 }
 
@@ -58,16 +60,16 @@
   }
 }
 
-static PyMethodDef DisableTorchFunction_methods[] = { // NOLINT
-    {"__enter__", DisableTorchFunction__enter, METH_NOARGS, nullptr},
-    {"__exit__", DisableTorchFunction__exit, METH_VARARGS, nullptr},
+static PyMethodDef DisableTorchFunctionSubclass_methods[] = { // NOLINT
+    {"__enter__", DisableTorchFunctionSubclass__enter, METH_NOARGS, nullptr},
+    {"__exit__", DisableTorchFunctionSubclass__exit, METH_VARARGS, nullptr},
     {nullptr, nullptr, 0, nullptr}};
 
-PyTypeObject DisableTorchFunctionType = {
+PyTypeObject DisableTorchFunctionSubclassType = {
     PyVarObject_HEAD_INIT(
         nullptr,
-        0) "torch._C.DisableTorchFunction", /* tp_name */
-    sizeof(DisableTorchFunction), /* tp_basicsize */
+        0) "torch._C.DisableTorchFunctionSubclass", /* tp_name */
+    sizeof(DisableTorchFunctionSubclass), /* tp_basicsize */
     0, /* tp_itemsize */
     nullptr, /* tp_dealloc */
     0, /* tp_vectorcall_offset */
@@ -92,7 +94,7 @@
     0, /* tp_weaklistoffset */
     nullptr, /* tp_iter */
     nullptr, /* tp_iternext */
-    DisableTorchFunction_methods, /* tp_methods */
+    DisableTorchFunctionSubclass_methods, /* tp_methods */
     nullptr, /* tp_members */
     nullptr, /* tp_getset */
     nullptr, /* tp_base */
@@ -105,12 +107,12 @@
     PyType_GenericNew, /* tp_new */
 };
 
-PyObject* THPModule_DisableTorchFunctionType() {
-  if (PyType_Ready(&DisableTorchFunctionType) < 0) {
+PyObject* THPModule_DisableTorchFunctionSubclassType() {
+  if (PyType_Ready(&DisableTorchFunctionSubclassType) < 0) {
     return nullptr;
   }
 
-  return (PyObject*)(&DisableTorchFunctionType);
+  return (PyObject*)(&DisableTorchFunctionSubclassType);
 }
 
 PyObject* THPModule_disable_torch_function(PyObject* self, PyObject* a) {
diff --git a/torch/csrc/utils/disable_torch_function.h b/torch/csrc/utils/disable_torch_function.h
index 3cdc33e..881a7ad 100644
--- a/torch/csrc/utils/disable_torch_function.h
+++ b/torch/csrc/utils/disable_torch_function.h
@@ -29,7 +29,7 @@
 } // namespace torch
 
 PyObject* THPModule_isEnabledTorchFunction(PyObject* self, PyObject* unused);
-PyObject* THPModule_DisableTorchFunctionType();
+PyObject* THPModule_DisableTorchFunctionSubclassType();
 PyObject* THPModule_disable_torch_function(PyObject* self, PyObject* args);
 PyObject* THPModule_disable_torch_dispatch(PyObject* self, PyObject* args);
 PyObject* THPModule_has_torch_function(PyObject*, PyObject* arg);
diff --git a/torch/distributed/_shard/common_op_utils.py b/torch/distributed/_shard/common_op_utils.py
index 08aa132..42d6592 100644
--- a/torch/distributed/_shard/common_op_utils.py
+++ b/torch/distributed/_shard/common_op_utils.py
@@ -53,11 +53,11 @@
         Handles ``__torch_function__`` dispatch for the default tensor ops that
         behave the same as ``torch.Tensor`` such as ``torch.Tensor.shape`` or
         ``torch.Tensor.dtype``. We simply lower to the real op call with
-        DisableTorchFunction context like ``torch.Tensor.__torch_function__``
+        DisableTorchFunctionSubclass context like ``torch.Tensor.__torch_function__``
         to avoid recursions.
         """
         if kwargs is None:
             kwargs = {}
 
-        with torch._C.DisableTorchFunction():
+        with torch._C.DisableTorchFunctionSubclass():
             return op(*args, **kwargs)
diff --git a/torch/distributed/_shard/partial_tensor.py b/torch/distributed/_shard/partial_tensor.py
index dc8d09b..6a48163 100644
--- a/torch/distributed/_shard/partial_tensor.py
+++ b/torch/distributed/_shard/partial_tensor.py
@@ -236,7 +236,7 @@
         # Need to disable all dispatch to print args and kwargs appropriately.
         guard = torch._C._DisableTorchDispatch()  # type: ignore[attr-defined]
         try:
-            with torch._C.DisableTorchFunction():
+            with torch._C.DisableTorchFunctionSubclass():
                 raise RuntimeError(
                     f"torch function '{func.__name__}', with args: {args} and "
                     f"kwargs: {kwargs} not supported for PartialTensor!")
diff --git a/torch/distributed/_shard/replicated_tensor.py b/torch/distributed/_shard/replicated_tensor.py
index 1327f89..e3db6b0 100644
--- a/torch/distributed/_shard/replicated_tensor.py
+++ b/torch/distributed/_shard/replicated_tensor.py
@@ -109,7 +109,7 @@
         # We cann't do super().__torch_function__() as it implicitly convert the result
         # back to tensor subclasses, where in our case, we need to control the output type
         # base on the inter-op rules we defined.
-        with torch._C.DisableTorchFunction():
+        with torch._C.DisableTorchFunctionSubclass():
             rs = func(*args, **kwargs)
             if func in get_default_nowrap_functions():
                 return rs
@@ -157,7 +157,7 @@
         return True
 
     def __setstate__(self, state):
-        with torch._C.DisableTorchFunction():
+        with torch._C.DisableTorchFunctionSubclass():
             self.data = state
             self.requires_grad = state.requires_grad
             from torch.distributed._shard.api import _get_current_process_group
diff --git a/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py b/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py
index e52c292..9ed83ee 100644
--- a/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py
+++ b/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py
@@ -203,7 +203,7 @@
         local_shard.tensor.requires_grad_(requires_grad)
 
         # update the wrapper class property
-    with torch._C.DisableTorchFunction():
+    with torch._C.DisableTorchFunctionSubclass():
         self_st.requires_grad_(requires_grad)
     # update the metadata in the meanwhile
     self_st._metadata.tensor_properties.requires_grad = requires_grad
diff --git a/torch/masked/maskedtensor/core.py b/torch/masked/maskedtensor/core.py
index 3274ef2..0459f24 100644
--- a/torch/masked/maskedtensor/core.py
+++ b/torch/masked/maskedtensor/core.py
@@ -270,7 +270,7 @@
 
         if not all(issubclass(cls, t) for t in types):
             return NotImplemented
-        with torch._C.DisableTorchFunction():
+        with torch._C.DisableTorchFunctionSubclass():
             ret = func(*args, **kwargs)
             if func in get_default_nowrap_functions():
                 return ret