wrap AliasDb in Python (#51336)
Summary:
Also added a wrapper tlemo 's graphviz export to string.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51336
Reviewed By: ezyang
Differential Revision: D26150809
Pulled By: eellison
fbshipit-source-id: 9beafce5cbdc1785b986b71c3cd986c1087faa11
diff --git a/test/jit/test_python_bindings.py b/test/jit/test_python_bindings.py
index b5fa9d3..9a37917 100644
--- a/test/jit/test_python_bindings.py
+++ b/test/jit/test_python_bindings.py
@@ -65,3 +65,13 @@
list(i)
o = test_iterator_keepalive_fn.inlined_graph.outputs()
list(o)
+
+ def test_aliasdb(self):
+ @torch.jit.script
+ def test_aliasdb_fn(x: torch.Tensor):
+ return 2 * x
+
+ gr = test_aliasdb_fn.graph.copy()
+ alias_db = gr.alias_db()
+ self.assertTrue("WILDCARD" in str(alias_db))
+ self.assertTrue("digraph alias_db" in alias_db.to_graphviz_str())
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 2bd2f41..3bc5b57 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -365,9 +365,15 @@
class GraphExecutorState:
...
+# Defined in torch/torch/csrc/jit/ir/alias_analysis.h
+class AliasDb:
+ def __str__(self) -> str: ...
+ ...
+
# Defined in torch/torch/csrc/jit/ir/ir.h
class Graph:
def eraseInput(self, i: _int) -> None: ...
+ def alias_db(self) -> AliasDb: ...
...
# Defined in torch/csrc/jit/ir/ir.h
diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp
index 7275d26..f0dace3 100644
--- a/torch/csrc/jit/ir/alias_analysis.cpp
+++ b/torch/csrc/jit/ir/alias_analysis.cpp
@@ -435,7 +435,7 @@
dot << toString();
dot << "*/\n";
- dot << "digraph fusion_ir {\n"
+ dot << "digraph alias_db {\n"
<< " rankdir=LR\n"
<< " node [shape=rect, color=gray];\n"
<< " edge [color=black];\n";
diff --git a/torch/csrc/jit/python/python_ir.cpp b/torch/csrc/jit/python/python_ir.cpp
index cfaf1cf..c9c7272 100644
--- a/torch/csrc/jit/python/python_ir.cpp
+++ b/torch/csrc/jit/python/python_ir.cpp
@@ -214,6 +214,12 @@
void initPythonIRBindings(PyObject* module_) {
auto m = py::handle(module_).cast<py::module>();
+
+ py::class_<AliasDb, std::shared_ptr<AliasDb>>(m, "AliasDb")
+ .def("dump", &AliasDb::dump)
+ .def("to_graphviz_str", &AliasDb::toGraphviz)
+ .def("__str__", &AliasDb::toString);
+
#define GS(name) def(#name, &Graph ::name)
py::class_<Graph, std::shared_ptr<Graph>>(m, "Graph")
.def(py::init<>())
@@ -228,6 +234,11 @@
[&](const bool enabled) { global_print_source_ranges = enabled; },
py::arg("enabled") = true)
.def(
+ "alias_db",
+ [](std::shared_ptr<Graph> g) {
+ return std::make_shared<AliasDb>(std::move(g));
+ })
+ .def(
"dump_alias_db",
[](std::shared_ptr<Graph> g) {
AliasDb db(std::move(g));