rename DisableTorchFunction to DisableTorchFunctionSubclass (#88218)
First half of #87990. This doesn't change any behavior; it is purely a rename of the `torch._C.DisableTorchFunction` context manager to `torch._C.DisableTorchFunctionSubclass`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88218
Approved by: https://github.com/ezyang, https://github.com/zou3519
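For context, below is a minimal, hypothetical sketch (not part of this PR) of the call pattern the rename touches: a `Tensor` subclass whose `__torch_function__` uses the renamed `torch._C.DisableTorchFunctionSubclass` context manager to run the real op without recursing back into the override, mirroring the call sites updated in this diff (e.g. `torch/_tensor.py`). The class name `LoggingTensor` is made up for illustration, and the snippet assumes a build that already includes this rename.

```python
# Hypothetical example; assumes a PyTorch build containing this rename.
import torch


class LoggingTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print(f"intercepted {func.__name__}")
        # Temporarily disable __torch_function__ dispatch for Tensor
        # subclasses while running the real op, so the call below does not
        # re-enter this override (same pattern as torch/_tensor.py).
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)


x = torch.randn(3).as_subclass(LoggingTensor)
y = torch.sum(x)  # the override fires once; the inner call does not recurse
```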
diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json
index ba4a2e9..8a66dc1 100644
--- a/test/allowlist_for_publicAPI.json
+++ b/test/allowlist_for_publicAPI.json
@@ -1128,7 +1128,7 @@
"BFloat16Tensor",
"ComplexDoubleStorage",
"ComplexFloatStorage",
- "DisableTorchFunction",
+ "DisableTorchFunctionSubclass",
"Generator",
"HalfStorage",
"HalfTensor",
diff --git a/test/profiler/test_profiler_tree.py b/test/profiler/test_profiler_tree.py
index d4a31c6..2105302 100644
--- a/test/profiler/test_profiler_tree.py
+++ b/test/profiler/test_profiler_tree.py
@@ -26,7 +26,7 @@
"torch/profiler/profiler.py(...): start": KEEP_ELLIPSES,
"torch/profiler/profiler.py(...): stop_trace": KEEP_ELLIPSES,
"torch/profiler/profiler.py(...): _transit_action": KEEP_ELLIPSES,
- "<built-in method __exit__ of torch._C.DisableTorchFunction object at 0xXXXXXXXXXXXX>": PRUNE_ALL,
+ "<built-in method __exit__ of torch._C.DisableTorchFunctionSubclass object at 0xXXXXXXXXXXXX>": PRUNE_ALL,
"cudaStreamIsCapturing": PRUNE_ALL,
"cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags": PRUNE_ALL,
}
diff --git a/test/test_overrides.py b/test/test_overrides.py
index 7082f75..01c763a 100644
--- a/test/test_overrides.py
+++ b/test/test_overrides.py
@@ -1448,7 +1448,7 @@
x = B(torch.randn(5))
with A():
- with torch._C.DisableTorchFunction():
+ with torch._C.DisableTorchFunctionSubclass():
self.assertNotIsInstance(torch.sum(x), B)
self.assertTrue(called)
@@ -1460,7 +1460,7 @@
pass
x = A(torch.randn(5))
- with torch._C.DisableTorchFunction():
+ with torch._C.DisableTorchFunctionSubclass():
g = torch._C._EnableTorchFunction()
try:
self.assertIsInstance(torch.sum(x), A)
diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py
index 4d2df65..6897c31 100644
--- a/test/test_public_bindings.py
+++ b/test/test_public_bindings.py
@@ -99,7 +99,7 @@
"device",
"DeviceObjType",
"DictType",
- "DisableTorchFunction",
+ "DisableTorchFunctionSubclass",
"DispatchKey",
"DispatchKeySet",
"dtype",
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 2d20da2..79dd638 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -108,7 +108,7 @@
...
# Defined in torch/csrc/utils/disable_torch_function.cpp
-def DisableTorchFunction(): ...
+def DisableTorchFunctionSubclass(): ...
# Defined in torch/csrc/utils/tensor_layouts.cpp
strided : layout = ...
diff --git a/torch/__init__.py b/torch/__init__.py
index ae55f59..ef6138c 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -288,7 +288,7 @@
if (isinstance(obj, Callable) or inspect.isclass(obj)): # type: ignore[arg-type]
if (obj.__module__ != 'torch'):
# TODO: fix their module from C++ side
- if name not in ['DisableTorchFunction', 'Generator']:
+ if name not in ['DisableTorchFunctionSubclass', 'Generator']:
obj.__module__ = 'torch'
if not TYPE_CHECKING:
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index d3c5140..9d87897 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -506,7 +506,7 @@
)
# Disable __torch_function__ to prevent cloning of `value` to hit
# us
- with torch._C.DisableTorchFunction():
+ with torch._C.DisableTorchFunctionSubclass():
if is_constant_source(self.get_source()):
return self.tx.output.register_attr_or_module(
value,
diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py
index da32712..6e4325b 100644
--- a/torch/_dynamo/variables/misc.py
+++ b/torch/_dynamo/variables/misc.py
@@ -538,7 +538,7 @@
options = VariableTracker.propagate(self, new_args, new_kwargs.values())
# Disable __torch_function__ here to prevent the clone of the
# example tensor from going into the override.
- with torch._C.DisableTorchFunction():
+ with torch._C.DisableTorchFunctionSubclass():
if isinstance(args[0], TorchVariable):
return TensorVariable.create(
tx=tx,
diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py
index 315c2b1..5a30f83 100644
--- a/torch/_dynamo/variables/tensor.py
+++ b/torch/_dynamo/variables/tensor.py
@@ -704,7 +704,7 @@
# Disable __torch_function__ here to prevent the clone of the
# example tensor from going into the override.
- with torch._C.DisableTorchFunction():
+ with torch._C.DisableTorchFunctionSubclass():
return tx.inline_user_function_return(tf_func_var, tf_args, {})
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index 14f5cd2..79af51e 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -1093,5 +1093,5 @@
memo[id(tensor)] = out
return out
else:
- with torch._C.DisableTorchFunction():
+ with torch._C.DisableTorchFunctionSubclass():
return func(*args, **kwargs)
diff --git a/torch/_tensor.py b/torch/_tensor.py
index 793034b..41b6569 100644
--- a/torch/_tensor.py
+++ b/torch/_tensor.py
@@ -1297,7 +1297,7 @@
if not all(issubclass(cls, t) for t in types):
return NotImplemented
- with _C.DisableTorchFunction():
+ with _C.DisableTorchFunctionSubclass():
ret = func(*args, **kwargs)
if func in get_default_nowrap_functions():
return ret
diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp
index b8693a4..efe6c18 100644
--- a/torch/csrc/Module.cpp
+++ b/torch/csrc/Module.cpp
@@ -1594,8 +1594,8 @@
(PyObject*)THPDefaultCPUGenerator,
/* incref= */ false));
ASSERT_TRUE(set_module_attr(
- "DisableTorchFunction",
- (PyObject*)THPModule_DisableTorchFunctionType(),
+ "DisableTorchFunctionSubclass",
+ (PyObject*)THPModule_DisableTorchFunctionSubclassType(),
/* incref= */ false));
torch::set_disabled_torch_function_impl(
PyObject_GetAttrString(module, "_disabled_torch_function_impl"));
diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp
index ee96323..d26db95 100644
--- a/torch/csrc/autograd/init.cpp
+++ b/torch/csrc/autograd/init.cpp
@@ -343,7 +343,6 @@
_C_m, "_RestorePythonTLSSnapshot")
.def(py::init<>());
- // TODO: line up this binding with DisableTorchFunction
py::class_<torch::DisableTorchDispatch>(_C_m, "_DisableTorchDispatch")
.def(py::init<>());
py::class_<EnableTorchFunction>(_C_m, "_EnableTorchFunction")
diff --git a/torch/csrc/utils/disable_torch_function.cpp b/torch/csrc/utils/disable_torch_function.cpp
index 682120d..516e6b8 100644
--- a/torch/csrc/utils/disable_torch_function.cpp
+++ b/torch/csrc/utils/disable_torch_function.cpp
@@ -35,18 +35,20 @@
PyObject_HEAD
/* Type-specific fields go here. */
bool old_state;
-} DisableTorchFunction;
+} DisableTorchFunctionSubclass;
-PyObject* DisableTorchFunction__enter(PyObject* self, PyObject* unused) {
- ((DisableTorchFunction*)self)->old_state =
+PyObject* DisableTorchFunctionSubclass__enter(
+ PyObject* self,
+ PyObject* unused) {
+ ((DisableTorchFunctionSubclass*)self)->old_state =
at::impl::PythonTorchFunctionTLS::is_disabled();
at::impl::PythonTorchFunctionTLS::set_disabled(true);
Py_RETURN_NONE;
}
-PyObject* DisableTorchFunction__exit(PyObject* self, PyObject* unused) {
+PyObject* DisableTorchFunctionSubclass__exit(PyObject* self, PyObject* unused) {
at::impl::PythonTorchFunctionTLS::set_disabled(
- ((DisableTorchFunction*)self)->old_state);
+ ((DisableTorchFunctionSubclass*)self)->old_state);
Py_RETURN_NONE;
}
@@ -58,16 +60,16 @@
}
}
-static PyMethodDef DisableTorchFunction_methods[] = { // NOLINT
- {"__enter__", DisableTorchFunction__enter, METH_NOARGS, nullptr},
- {"__exit__", DisableTorchFunction__exit, METH_VARARGS, nullptr},
+static PyMethodDef DisableTorchFunctionSubclass_methods[] = { // NOLINT
+ {"__enter__", DisableTorchFunctionSubclass__enter, METH_NOARGS, nullptr},
+ {"__exit__", DisableTorchFunctionSubclass__exit, METH_VARARGS, nullptr},
{nullptr, nullptr, 0, nullptr}};
-PyTypeObject DisableTorchFunctionType = {
+PyTypeObject DisableTorchFunctionSubclassType = {
PyVarObject_HEAD_INIT(
nullptr,
- 0) "torch._C.DisableTorchFunction", /* tp_name */
- sizeof(DisableTorchFunction), /* tp_basicsize */
+ 0) "torch._C.DisableTorchFunctionSubclass", /* tp_name */
+ sizeof(DisableTorchFunctionSubclass), /* tp_basicsize */
0, /* tp_itemsize */
nullptr, /* tp_dealloc */
0, /* tp_vectorcall_offset */
@@ -92,7 +94,7 @@
0, /* tp_weaklistoffset */
nullptr, /* tp_iter */
nullptr, /* tp_iternext */
- DisableTorchFunction_methods, /* tp_methods */
+ DisableTorchFunctionSubclass_methods, /* tp_methods */
nullptr, /* tp_members */
nullptr, /* tp_getset */
nullptr, /* tp_base */
@@ -105,12 +107,12 @@
PyType_GenericNew, /* tp_new */
};
-PyObject* THPModule_DisableTorchFunctionType() {
- if (PyType_Ready(&DisableTorchFunctionType) < 0) {
+PyObject* THPModule_DisableTorchFunctionSubclassType() {
+ if (PyType_Ready(&DisableTorchFunctionSubclassType) < 0) {
return nullptr;
}
- return (PyObject*)(&DisableTorchFunctionType);
+ return (PyObject*)(&DisableTorchFunctionSubclassType);
}
PyObject* THPModule_disable_torch_function(PyObject* self, PyObject* a) {
diff --git a/torch/csrc/utils/disable_torch_function.h b/torch/csrc/utils/disable_torch_function.h
index 3cdc33e..881a7ad 100644
--- a/torch/csrc/utils/disable_torch_function.h
+++ b/torch/csrc/utils/disable_torch_function.h
@@ -29,7 +29,7 @@
} // namespace torch
PyObject* THPModule_isEnabledTorchFunction(PyObject* self, PyObject* unused);
-PyObject* THPModule_DisableTorchFunctionType();
+PyObject* THPModule_DisableTorchFunctionSubclassType();
PyObject* THPModule_disable_torch_function(PyObject* self, PyObject* args);
PyObject* THPModule_disable_torch_dispatch(PyObject* self, PyObject* args);
PyObject* THPModule_has_torch_function(PyObject*, PyObject* arg);
diff --git a/torch/distributed/_shard/common_op_utils.py b/torch/distributed/_shard/common_op_utils.py
index 08aa132..42d6592 100644
--- a/torch/distributed/_shard/common_op_utils.py
+++ b/torch/distributed/_shard/common_op_utils.py
@@ -53,11 +53,11 @@
Handles ``__torch_function__`` dispatch for the default tensor ops that
behave the same as ``torch.Tensor`` such as ``torch.Tensor.shape`` or
``torch.Tensor.dtype``. We simply lower to the real op call with
- DisableTorchFunction context like ``torch.Tensor.__torch_function__``
+ DisableTorchFunctionSubclass context like ``torch.Tensor.__torch_function__``
to avoid recursions.
"""
if kwargs is None:
kwargs = {}
- with torch._C.DisableTorchFunction():
+ with torch._C.DisableTorchFunctionSubclass():
return op(*args, **kwargs)
diff --git a/torch/distributed/_shard/partial_tensor.py b/torch/distributed/_shard/partial_tensor.py
index dc8d09b..6a48163 100644
--- a/torch/distributed/_shard/partial_tensor.py
+++ b/torch/distributed/_shard/partial_tensor.py
@@ -236,7 +236,7 @@
# Need to disable all dispatch to print args and kwargs appropriately.
guard = torch._C._DisableTorchDispatch() # type: ignore[attr-defined]
try:
- with torch._C.DisableTorchFunction():
+ with torch._C.DisableTorchFunctionSubclass():
raise RuntimeError(
f"torch function '{func.__name__}', with args: {args} and "
f"kwargs: {kwargs} not supported for PartialTensor!")
diff --git a/torch/distributed/_shard/replicated_tensor.py b/torch/distributed/_shard/replicated_tensor.py
index 1327f89..e3db6b0 100644
--- a/torch/distributed/_shard/replicated_tensor.py
+++ b/torch/distributed/_shard/replicated_tensor.py
@@ -109,7 +109,7 @@
# We cann't do super().__torch_function__() as it implicitly convert the result
# back to tensor subclasses, where in our case, we need to control the output type
# base on the inter-op rules we defined.
- with torch._C.DisableTorchFunction():
+ with torch._C.DisableTorchFunctionSubclass():
rs = func(*args, **kwargs)
if func in get_default_nowrap_functions():
return rs
@@ -157,7 +157,7 @@
return True
def __setstate__(self, state):
- with torch._C.DisableTorchFunction():
+ with torch._C.DisableTorchFunctionSubclass():
self.data = state
self.requires_grad = state.requires_grad
from torch.distributed._shard.api import _get_current_process_group
diff --git a/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py b/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py
index e52c292..9ed83ee 100644
--- a/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py
+++ b/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py
@@ -203,7 +203,7 @@
local_shard.tensor.requires_grad_(requires_grad)
# update the wrapper class property
- with torch._C.DisableTorchFunction():
+ with torch._C.DisableTorchFunctionSubclass():
self_st.requires_grad_(requires_grad)
# update the metadata in the meanwhile
self_st._metadata.tensor_properties.requires_grad = requires_grad
diff --git a/torch/masked/maskedtensor/core.py b/torch/masked/maskedtensor/core.py
index 3274ef2..0459f24 100644
--- a/torch/masked/maskedtensor/core.py
+++ b/torch/masked/maskedtensor/core.py
@@ -270,7 +270,7 @@
if not all(issubclass(cls, t) for t in types):
return NotImplemented
- with torch._C.DisableTorchFunction():
+ with torch._C.DisableTorchFunctionSubclass():
ret = func(*args, **kwargs)
if func in get_default_nowrap_functions():
return ret