Allow specifying alias analysis while registering new ops
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77690
Approved by: https://github.com/ezyang
diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py
index eeeb29a..9aa5034 100644
--- a/test/test_python_dispatch.py
+++ b/test/test_python_dispatch.py
@@ -5,7 +5,8 @@
from copy import deepcopy
from torch.library import Library
from torch.cuda.jiterator import _create_jit_fn
-from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_ROCM
+import unittest
+from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_ROCM, IS_WINDOWS
from torch.testing._internal.logging_tensor import LoggingTensor, LoggingTensorReentrant, LoggingTensorMode, \
log_input, capture_logs, no_dispatch
from torch.utils._pytree import tree_map
@@ -247,6 +248,28 @@
del my_lib2
del my_lib1
+ @unittest.skipIf(IS_WINDOWS, "Skipped under Windows")
+ def test_alias_analysis(self):
+ def test_helper(alias_analysis=""):
+ my_lib1 = Library("foo", "DEF")
+
+ called = [0]
+
+ @torch.library.define(my_lib1, "_op() -> None", alias_analysis=alias_analysis)
+ def _op(*args, **kwargs):
+ called[0] += 1
+
+ @torch.jit.script
+ def _test():
+ torch.ops.foo._op()
+
+ assert "foo::_op" in str(_test.graph)
+
+ with self.assertRaises(AssertionError):
+ test_helper("") # alias_analysis="FROM_SCHEMA"
+
+ test_helper("CONSERVATIVE")
+
class TestPythonDispatch(TestCase):
def test_basic(self) -> None:
with capture_logs() as logs:
diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp
index 34a867f..eeb5b02 100644
--- a/torch/csrc/utils/python_dispatch.cpp
+++ b/torch/csrc/utils/python_dispatch.cpp
@@ -64,7 +64,7 @@
py::gil_scoped_acquire g;
auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
auto obj = py::reinterpret_steal<py::object>(PyObject_Call(func_.ptr(getPyInterpreter()), args_kwargs.first.ptr(), args_kwargs.second.ptr()));
- if (obj == nullptr) { throw python_error(); }
+ if (!obj) { throw python_error(); }
pushPyOutToStack(op, stack, obj, "PythonKernelHolder");
}
};
@@ -149,9 +149,10 @@
);
END_HANDLE_TH_ERRORS_PYBIND
}, "", py::arg("name"), py::arg("dispatch"), py::arg("func"))
- .def("define", [](py::object self, const char* schema) {
- self.cast<torch::Library&>().def(torch::schema(schema, c10::AliasAnalysisKind::FROM_SCHEMA));
- }, "", py::arg("schema"))
+ .def("define", [](py::object self, const char* schema, const char* alias_analysis) {
+ self.cast<torch::Library&>().def(torch::schema(schema, parseAliasAnalysisKind(alias_analysis)));
+ return torch::schema(schema, parseAliasAnalysisKind(alias_analysis)).name();
+ }, "", py::arg("schema"), py::arg("alias_analysis") = "")
.def("fallback_fallthrough", [](py::object self, const char* dispatch) {
self.cast<torch::Library&>().fallback(
dispatch_str(dispatch, CppFunction::makeFallthrough())
diff --git a/torch/library.py b/torch/library.py
index 03e1d73..0a07695 100644
--- a/torch/library.py
+++ b/torch/library.py
@@ -3,7 +3,7 @@
import traceback
import torch
-__all__ = ['Library', 'impl']
+__all__ = ['Library', 'impl', 'define']
# Set containing the combination of (namespace, operator, DispatchKey) for which a new kernel has been registered
# The keys in the set are of the form `namespace + "/" + op_name + "/" + dispatch_key`.
@@ -39,8 +39,6 @@
def impl(self, op_name, fn, dispatch_key=''):
if dispatch_key == '':
- if self.dispatch_key == '':
- raise RuntimeError("Please specify the dispatch key that you want to register the kernel for.")
dispatch_key = self.dispatch_key
if isinstance(op_name, str):
@@ -65,8 +63,19 @@
_impls.add(key)
self._op_impls.add(key)
- def define(self, schema):
- self.m.define(schema)
+ def define(self, schema, alias_analysis=""):
+ '''
+ Takes a schema to define a new operator.
+ Also, optionally takes `alias_analysis` argument to indicate if the aliasing properties of the arguments
+ can be inferred from the schema (default behavior) or not ("CONSERVATIVE").
+
+ Returns the name of the operator as inferred from the schema.
+ '''
+ # This is added because we also want to disallow PURE_FUNCTION alias analysis which is a valid
+ # AliasAnalysis type in C++
+ if alias_analysis not in ["", "FROM_SCHEMA", "CONSERVATIVE"]:
+ raise RuntimeError("Invalid alias_analysis type")
+ return self.m.define(schema, alias_analysis)
def __del__(self):
for key in self._op_impls:
@@ -75,7 +84,13 @@
# decorator to register python functions for library ops
# Note: this decorator API should remain consistent with `Library.impl` API
-def impl(lib, name, dispatch_key=''):
+def impl(lib, name, dispatch_key=""):
def wrap(f):
lib.impl(name, f, dispatch_key)
return wrap
+
+def define(lib, schema, alias_analysis=""):
+ def wrap(f):
+ name = lib.define(schema, alias_analysis)
+ lib.impl(name, f)
+ return wrap