Change HigherOrderOperator default namespace from global to 'higher_order' (#103870)

This PR changes the default namespace for higher-order operators from the
global namespace (e.g. torch.ops.cond) to `higher_order` (e.g.
torch.ops.higher_order.cond). Existing HigherOrderOperators keep their
global namespace for now via the new `_deprecated_global_ns=True` flag.
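
For example (a minimal sketch; `my_op` is a hypothetical operator used only
for illustration, not one added by this PR):

```python
import torch
from torch._ops import HigherOrderOperator

# New operators now land in the `higher_order` namespace by default.
my_op = HigherOrderOperator("my_op")
assert my_op.namespace == "higher_order"
assert torch.ops.higher_order.my_op is my_op
```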

The motivation is to stem the bleeding: exposing operators in the global
namespace is bad practice because their names can collide with user-defined
operator namespaces.

We will migrate the remaining `_deprecated_global_ns=True` operators to the
`higher_order` namespace as necessary after this diff.
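
Until then, operators created with the escape hatch keep resolving at the top
level, checked first in `_Ops.__getattr__` (again a sketch, with a hypothetical
`legacy_op`):

```python
import torch
from torch._ops import HigherOrderOperator

# The escape hatch: registers into the deprecated global namespace.
legacy_op = HigherOrderOperator("legacy_op", _deprecated_global_ns=True)
assert legacy_op.namespace is None
assert torch.ops.legacy_op is legacy_op
```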

Differential Revision: [D46809738](https://our.internmc.facebook.com/intern/diff/D46809738/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103870
Approved by: https://github.com/ydwu4
diff --git a/functorch/experimental/_cond.py b/functorch/experimental/_cond.py
index 088aa67..270c59b 100644
--- a/functorch/experimental/_cond.py
+++ b/functorch/experimental/_cond.py
@@ -34,7 +34,7 @@
 We're going to define a `cond` operation.
 In order to do this, we need implementations for each of the dispatch keys.
 """
-cond = HigherOrderOperator("cond")
+cond = HigherOrderOperator("cond", _deprecated_global_ns=True)
 
 def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
     assert isinstance(operands, (list, tuple)), "Cond operands must be a list or tuple of tensors"
diff --git a/functorch/experimental/_map.py b/functorch/experimental/_map.py
index d13aa9c..f4c479a 100644
--- a/functorch/experimental/_map.py
+++ b/functorch/experimental/_map.py
@@ -29,8 +29,8 @@
     def __call__(self, xs, *args):
         return map_wrapper(xs, *args)
 
-map = MapWrapper("map")
-map_impl = HigherOrderOperator("map_impl")
+map = MapWrapper("map", _deprecated_global_ns=True)
+map_impl = HigherOrderOperator("map_impl", _deprecated_global_ns=True)
 
 dummy_aot_config = AOTConfig(fw_compiler=None,
                              bw_compiler=None,
diff --git a/torch/_higher_order_ops/wrap.py b/torch/_higher_order_ops/wrap.py
index 2a05fb8..db83e86 100644
--- a/torch/_higher_order_ops/wrap.py
+++ b/torch/_higher_order_ops/wrap.py
@@ -6,7 +6,7 @@
 # Used for testing the HigherOrderOperator mechanism
 class Wrap(HigherOrderOperator):
     def __init__(self):
-        super().__init__("wrap")
+        super().__init__("wrap", _deprecated_global_ns=True)
 
     def __call__(self, func, *args):
         result = func(*args)
@@ -21,7 +21,7 @@
     checkpointed (the first arg to the utils.checkpoint() function).
     """
     def __init__(self):
-        super().__init__("wrap_activation_checkpoint")
+        super().__init__("wrap_activation_checkpoint", _deprecated_global_ns=True)
 
     def __call__(self, function, *args, **kwargs):
         # use_reentrant is set to False because this op is going to be traced.
@@ -47,7 +47,7 @@
     """
 
     def __init__(self):
-        super().__init__("tag_activation_checkpoint")
+        super().__init__("tag_activation_checkpoint", _deprecated_global_ns=True)
 
     def tag_nodes(self, gmod):
         # TODO - This needs major investigation. Currently, we are tagging all
diff --git a/torch/_ops.py b/torch/_ops.py
index 811a263..29b537d 100644
--- a/torch/_ops.py
+++ b/torch/_ops.py
@@ -190,19 +190,36 @@
     raise NotImplementedError(f"could not find kernel for {op} at dispatch key {k}")
 
 
-pyop_namespace = {}
+_global_higher_order_ops = {}
+_higher_order_ops = {}
 
 
 class HigherOrderOperator(OperatorBase):
-    def __init__(self, name):
+    # _deprecated_global_ns: Whether or not the HigherOrderOperator appears as:
+    # (True) torch.ops.{name}
+    # (False) torch.ops.higher_order.{name}
+    #
+    # If you're creating a new HigherOrderOperator, please do not change the
+    # default. Adding operators to the global torch.ops namespace is a bad
+    # practice due to name collisions.
+    def __init__(self, name, *, _deprecated_global_ns=False):
         super().__init__()
         self._name = name
 
         # Make _OPNamespace not scream, this whole name based association needs a good hard look
         self.__name__ = name
-        pyop_namespace[name] = self
+        if _deprecated_global_ns:
+            _global_higher_order_ops[name] = self
+            self._ns = None
+        else:
+            _higher_order_ops[name] = self
+            self._ns = "higher_order"
         self.non_fallthrough_keys = torch._C._dispatch_keyset_full()
 
+    @property
+    def namespace(self):
+        return self._ns
+
     def fallthrough(self, dispatch_key):
         self.non_fallthrough_keys = self.non_fallthrough_keys.remove(dispatch_key)
 
@@ -744,9 +761,16 @@
 
 
 class _PyOpNamespace(_OpNamespace):
-    def __init__(self):
-        super().__init__("torch.ops")
-        self.pyop_namespace = pyop_namespace
+    def __init__(self, name, ops):
+        super().__init__(name)
+        self._ops = ops
+
+    def __getattr__(self, name):
+        op = self._ops.get(name, None)
+        if op is None:
+            # Raise AttributeError (not KeyError) so hasattr()/getattr() work.
+            raise AttributeError(f"'{self.name}' has no attribute '{name}'")
+        return op
 
 
 class _Ops(types.ModuleType):
@@ -755,13 +779,20 @@
     def __init__(self):
         super().__init__("torch.ops")
         self.loaded_libraries = set()
-        self.pyops = _PyOpNamespace()
+        self._global_higher_order_op_namespace = _PyOpNamespace(
+            "torch.ops", _global_higher_order_ops
+        )
+        self._higher_order_op_namespace = _PyOpNamespace(
+            "torch.ops.higher_order", _higher_order_ops
+        )
         self._dir = []
 
     def __getattr__(self, name):
-        # Check if the name is a pyop
-        if name in self.pyops.pyop_namespace:
-            return self.pyops.pyop_namespace[name]
+        # Check if the name is a HigherOrderOperator
+        if name in self._global_higher_order_op_namespace._ops:
+            return getattr(self._global_higher_order_op_namespace, name)
+        if name == "higher_order":
+            return self._higher_order_op_namespace
 
         # Here we are creating `torch.ops.my_namespace`
         namespace = _OpNamespace(name)
diff --git a/torch/_prims/rng_prims.py b/torch/_prims/rng_prims.py
index d4bcd98..79653df 100644
--- a/torch/_prims/rng_prims.py
+++ b/torch/_prims/rng_prims.py
@@ -164,7 +164,9 @@
 
 
 def register_run_and_save_rng_state_op():
-    run_and_save_rng_state = HigherOrderOperator("run_and_save_rng_state")
+    run_and_save_rng_state = HigherOrderOperator(
+        "run_and_save_rng_state", _deprecated_global_ns=True
+    )
 
     run_and_save_rng_state.fallthrough(DispatchKey.ADInplaceOrView)
     run_and_save_rng_state.fallthrough(DispatchKey.PythonDispatcher)  # type: ignore[attr-defined]
@@ -218,7 +220,9 @@
 
 
 def register_run_with_rng_state_op():
-    run_with_rng_state = HigherOrderOperator("run_with_rng_state")
+    run_with_rng_state = HigherOrderOperator(
+        "run_with_rng_state", _deprecated_global_ns=True
+    )
 
     run_with_rng_state.fallthrough(DispatchKey.ADInplaceOrView)
     run_with_rng_state.fallthrough(DispatchKey.PythonTLSSnapshot)  # type: ignore[attr-defined]