Allow specifying alias analysis while registering new ops

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77690

Approved by: https://github.com/ezyang
diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py
index eeeb29a..9aa5034 100644
--- a/test/test_python_dispatch.py
+++ b/test/test_python_dispatch.py
@@ -5,7 +5,8 @@
 from copy import deepcopy
 from torch.library import Library
 from torch.cuda.jiterator import _create_jit_fn
-from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_ROCM
+import unittest
+from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_ROCM, IS_WINDOWS
 from torch.testing._internal.logging_tensor import LoggingTensor, LoggingTensorReentrant, LoggingTensorMode, \
     log_input, capture_logs, no_dispatch
 from torch.utils._pytree import tree_map
@@ -247,6 +248,28 @@
         del my_lib2
         del my_lib1
 
+    @unittest.skipIf(IS_WINDOWS, "Skipped under Windows")
+    def test_alias_analysis(self):
+        def test_helper(alias_analysis=""):
+            my_lib1 = Library("foo", "DEF")
+
+            called = [0]
+
+            @torch.library.define(my_lib1, "_op() -> None", alias_analysis=alias_analysis)
+            def _op(*args, **kwargs):
+                called[0] += 1
+
+            @torch.jit.script
+            def _test():
+                torch.ops.foo._op()
+
+            assert "foo::_op" in str(_test.graph)
+
+        with self.assertRaises(AssertionError):
+            test_helper("")  # alias_analysis="FROM_SCHEMA"
+
+        test_helper("CONSERVATIVE")
+
 class TestPythonDispatch(TestCase):
     def test_basic(self) -> None:
         with capture_logs() as logs:
diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp
index 34a867f..eeb5b02 100644
--- a/torch/csrc/utils/python_dispatch.cpp
+++ b/torch/csrc/utils/python_dispatch.cpp
@@ -64,7 +64,7 @@
     py::gil_scoped_acquire g;
     auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
     auto obj = py::reinterpret_steal<py::object>(PyObject_Call(func_.ptr(getPyInterpreter()), args_kwargs.first.ptr(), args_kwargs.second.ptr()));
-    if (obj == nullptr) { throw python_error(); }
+    if (!obj) { throw python_error(); }
     pushPyOutToStack(op, stack, obj, "PythonKernelHolder");
   }
 };
@@ -149,9 +149,10 @@
       );
       END_HANDLE_TH_ERRORS_PYBIND
     }, "", py::arg("name"), py::arg("dispatch"), py::arg("func"))
-    .def("define", [](py::object self, const char* schema) {
-      self.cast<torch::Library&>().def(torch::schema(schema, c10::AliasAnalysisKind::FROM_SCHEMA));
-    }, "", py::arg("schema"))
+    .def("define", [](py::object self, const char* schema, const char* alias_analysis) {
+      self.cast<torch::Library&>().def(torch::schema(schema, parseAliasAnalysisKind(alias_analysis)));
+      return torch::schema(schema, parseAliasAnalysisKind(alias_analysis)).name();
+    }, "", py::arg("schema"), py::arg("alias_analysis") = "")
     .def("fallback_fallthrough", [](py::object self, const char* dispatch) {
       self.cast<torch::Library&>().fallback(
         dispatch_str(dispatch, CppFunction::makeFallthrough())
diff --git a/torch/library.py b/torch/library.py
index 03e1d73..0a07695 100644
--- a/torch/library.py
+++ b/torch/library.py
@@ -3,7 +3,7 @@
 import traceback
 import torch
 
-__all__ = ['Library', 'impl']
+__all__ = ['Library', 'impl', 'define']
 
 # Set containing the combination of (namespace, operator, DispatchKey) for which a new kernel has been registered
 # The keys in the set are of the form `namespace + "/" + op_name + "/" + dispatch_key`.
@@ -39,8 +39,6 @@
 
     def impl(self, op_name, fn, dispatch_key=''):
         if dispatch_key == '':
-            if self.dispatch_key == '':
-                raise RuntimeError("Please specify the dispatch key that you want to register the kernel for.")
             dispatch_key = self.dispatch_key
 
         if isinstance(op_name, str):
@@ -65,8 +63,19 @@
         _impls.add(key)
         self._op_impls.add(key)
 
-    def define(self, schema):
-        self.m.define(schema)
+    def define(self, schema, alias_analysis=""):
+        '''
+        Takes a schema to define a new operator.
+        Also, optionally takes `alias_analysis` argument to indicate if the aliasing properties of the arguments
+        can be inferred from the schema (default behavior) or not ("CONSERVATIVE").
+
+        Returns the name of the operator as inferred from the schema.
+        '''
+        # This is added because we also want to disallow PURE_FUNCTION alias analysis which is a valid
+        # AliasAnalysis type in C++
+        if alias_analysis not in ["", "FROM_SCHEMA", "CONSERVATIVE"]:
+            raise RuntimeError("Invalid alias_analysis type")
+        return self.m.define(schema, alias_analysis)
 
     def __del__(self):
         for key in self._op_impls:
@@ -75,7 +84,13 @@
 
 # decorator to register python functions for library ops
 # Note: this decorator API should remain consistent with `Library.impl` API
-def impl(lib, name, dispatch_key=''):
+def impl(lib, name, dispatch_key=""):
     def wrap(f):
         lib.impl(name, f, dispatch_key)
     return wrap
+
+def define(lib, schema, alias_analysis=""):
+    def wrap(f):
+        name = lib.define(schema, alias_analysis)
+        lib.impl(name, f)
+    return wrap