Change HigherOrderOperator default namespace from global to 'higher_order' (#103870)
This PR changes the default namespace for higher order operators from the
global namespace (e.g. torch.ops.cond) to `higher_order` (e.g.
torch.ops.higher_order.cond). We don't actually change the namespace
for existing HigherOrderOperators.
The motivation is to stem the bleeding; exposing operators into the global
namespace is a bad idea due to name collision with other user-defined
namespaces.
We will go in and fix the `_deprecated_global_ns` as necessary after this diff.
Differential Revision: [D46809738](https://our.internmc.facebook.com/intern/diff/D46809738/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103870
Approved by: https://github.com/ydwu4
diff --git a/functorch/experimental/_cond.py b/functorch/experimental/_cond.py
index 088aa67..270c59b 100644
--- a/functorch/experimental/_cond.py
+++ b/functorch/experimental/_cond.py
@@ -34,7 +34,7 @@
We're going to define a `cond` operation.
In order to do this, we need implementations for each of the dispatch keys.
"""
-cond = HigherOrderOperator("cond")
+cond = HigherOrderOperator("cond", _deprecated_global_ns=True)
def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
assert isinstance(operands, (list, tuple)), "Cond operands must be a list or tuple of tensors"
diff --git a/functorch/experimental/_map.py b/functorch/experimental/_map.py
index d13aa9c..f4c479a 100644
--- a/functorch/experimental/_map.py
+++ b/functorch/experimental/_map.py
@@ -29,8 +29,8 @@
def __call__(self, xs, *args):
return map_wrapper(xs, *args)
-map = MapWrapper("map")
-map_impl = HigherOrderOperator("map_impl")
+map = MapWrapper("map", _deprecated_global_ns=True)
+map_impl = HigherOrderOperator("map_impl", _deprecated_global_ns=True)
dummy_aot_config = AOTConfig(fw_compiler=None,
bw_compiler=None,
diff --git a/torch/_higher_order_ops/wrap.py b/torch/_higher_order_ops/wrap.py
index 2a05fb8..db83e86 100644
--- a/torch/_higher_order_ops/wrap.py
+++ b/torch/_higher_order_ops/wrap.py
@@ -6,7 +6,7 @@
# Used for testing the HigherOrderOperator mechanism
class Wrap(HigherOrderOperator):
def __init__(self):
- super().__init__("wrap")
+ super().__init__("wrap", _deprecated_global_ns=True)
def __call__(self, func, *args):
result = func(*args)
@@ -21,7 +21,7 @@
checkpointed (the first arg to the utils.checkpoint() function).
"""
def __init__(self):
- super().__init__("wrap_activation_checkpoint")
+ super().__init__("wrap_activation_checkpoint", _deprecated_global_ns=True)
def __call__(self, function, *args, **kwargs):
# use_reentrant is set to False because this op is going to be traced.
@@ -47,7 +47,7 @@
"""
def __init__(self):
- super().__init__("wrap_activation_checkpoint")
+ super().__init__("wrap_activation_checkpoint", _deprecated_global_ns=True)
def tag_nodes(self, gmod):
# TODO - This needs major investigation. Currently, we are tagging all
diff --git a/torch/_ops.py b/torch/_ops.py
index 811a263..29b537d 100644
--- a/torch/_ops.py
+++ b/torch/_ops.py
@@ -190,19 +190,36 @@
raise NotImplementedError(f"could not find kernel for {op} at dispatch key {k}")
-pyop_namespace = {}
+_global_higher_order_ops = {}
+_higher_order_ops = {}
class HigherOrderOperator(OperatorBase):
- def __init__(self, name):
+ # _deprecated_global_ns: Whether or not the HigherOrderOperator appears as:
+ # (True) torch.ops.{name}
+ # (False) torch.ops.higher_order.{name}
+ #
+ # If you're creating a new HigherOrderOperator, please do not change the
+ # default. Adding operators to the global torch.ops namespace is a bad
+ # practice due to name collisions.
+ def __init__(self, name, *, _deprecated_global_ns=False):
super().__init__()
self._name = name
# Make _OPNamespace not scream, this whole name based association needs a good hard look
self.__name__ = name
- pyop_namespace[name] = self
+ if _deprecated_global_ns:
+ _global_higher_order_ops[name] = self
+ self._ns = None
+ else:
+ _higher_order_ops[name] = self
+ self._ns = "higher_order"
self.non_fallthrough_keys = torch._C._dispatch_keyset_full()
+ @property
+ def namespace(self):
+ return self._ns
+
def fallthrough(self, dispatch_key):
self.non_fallthrough_keys = self.non_fallthrough_keys.remove(dispatch_key)
@@ -744,9 +761,12 @@
class _PyOpNamespace(_OpNamespace):
- def __init__(self):
- super().__init__("torch.ops")
- self.pyop_namespace = pyop_namespace
+ def __init__(self, name, ops):
+ super().__init__(name)
+ self._ops = ops
+
+ def __getattr__(self, name):
+ return self._ops[name]
class _Ops(types.ModuleType):
@@ -755,13 +775,20 @@
def __init__(self):
super().__init__("torch.ops")
self.loaded_libraries = set()
- self.pyops = _PyOpNamespace()
+ self._global_higher_order_op_namespace = _PyOpNamespace(
+ "torch.ops", _global_higher_order_ops
+ )
+ self._higher_order_op_namespace = _PyOpNamespace(
+ "torch.ops.higher_order", _higher_order_ops
+ )
self._dir = []
def __getattr__(self, name):
- # Check if the name is a pyop
- if name in self.pyops.pyop_namespace:
- return self.pyops.pyop_namespace[name]
+ # Check if the name is a HigherOrderOperator
+ if name in self._global_higher_order_op_namespace._ops:
+ return getattr(self._global_higher_order_op_namespace, name)
+ if name == "higher_order":
+ return self._higher_order_op_namespace
# Here we are creating `torch.ops.my_namespace`
namespace = _OpNamespace(name)
diff --git a/torch/_prims/rng_prims.py b/torch/_prims/rng_prims.py
index d4bcd98..79653df 100644
--- a/torch/_prims/rng_prims.py
+++ b/torch/_prims/rng_prims.py
@@ -164,7 +164,9 @@
def register_run_and_save_rng_state_op():
- run_and_save_rng_state = HigherOrderOperator("run_and_save_rng_state")
+ run_and_save_rng_state = HigherOrderOperator(
+ "run_and_save_rng_state", _deprecated_global_ns=True
+ )
run_and_save_rng_state.fallthrough(DispatchKey.ADInplaceOrView)
run_and_save_rng_state.fallthrough(DispatchKey.PythonDispatcher) # type: ignore[attr-defined]
@@ -218,7 +220,9 @@
def register_run_with_rng_state_op():
- run_with_rng_state = HigherOrderOperator("run_with_rng_state")
+ run_with_rng_state = HigherOrderOperator(
+ "run_with_rng_state", _deprecated_global_ns=True
+ )
run_with_rng_state.fallthrough(DispatchKey.ADInplaceOrView)
run_with_rng_state.fallthrough(DispatchKey.PythonTLSSnapshot) # type: ignore[attr-defined]