# Owner(s): ["oncall: export"]
import unittest
from typing import Any, Dict, Optional, OrderedDict, Tuple

import torch
from torch._export.passes.lift_constants_pass import (
    ConstantAttrMap,
    lift_constants_pass,
)
from torch.export._unlift import _unlift_exported_program_lifted_states
from torch.export.exported_program import (
    ExportGraphSignature,
    InputKind,
    InputSpec,
    OutputKind,
    OutputSpec,
    TensorArgument,
)
from torch.export.graph_signature import CustomObjArgument
from torch.testing._internal.common_utils import (
    find_library_location,
    IS_FBCODE,
    IS_MACOS,
    IS_SANDCASTLE,
    IS_WINDOWS,
    run_tests,
    TestCase,
)
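

# GraphBuilder hand-constructs torch.fx graphs carrying the node metadata that
# the export passes under test expect ("val" and "nn_module_stack"), so that
# lift_constants_pass can be exercised without running a full torch.export trace.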
class GraphBuilder:
    def __init__(self):
        self.graph = torch.fx.Graph()
        self.nodes = {}
        self.values = {}
        self.nn_module_stack_key: Dict[str, int] = {}
        self.latest_id = 0
        self.input_to_kind: Dict[torch.fx.Node, InputKind] = {}

    def input(self, name: str, value: torch.Tensor, kind: InputKind):
        node = self.graph.placeholder(name)
        node.meta["val"] = value
        self.nodes[name] = node
        self.values[name] = value
        self.input_to_kind[node] = kind

    def add(self, x: str, y: str, out: str, module_fqn: str = ""):
        node = self.graph.create_node(
            "call_function",
            torch.ops.aten.add.Tensor,
            (self.nodes[x], self.nodes[y]),
            name=out,
        )
        self.values[out] = self.values[x] + self.values[y]
        node.meta["val"] = self.values[out]
        node.meta["nn_module_stack"] = self.create_nn_module_stack(module_fqn)
        self.nodes[out] = node

    def call_function(self, target, args, out: str, module_fqn: str = ""):
        arg_nodes = tuple(self.nodes[arg] for arg in args)
        arg_values = tuple(self.values[arg] for arg in args)
        node = self.graph.create_node(
            "call_function",
            target,
            arg_nodes,
            name=out,
        )
        self.values[out] = target(*arg_values)
        node.meta["val"] = self.values[out]
        node.meta["nn_module_stack"] = self.create_nn_module_stack(module_fqn)
        self.nodes[out] = node

    def constant(
        self, name: str, value: Any, target: Optional[str] = None, module_fqn: str = ""
    ):
        if target is None:
            target = name
        node = self.graph.get_attr(target)
        node.meta["val"] = value
        node.meta["nn_module_stack"] = self.create_nn_module_stack(module_fqn)
        self.nodes[name] = node
        self.values[name] = value

    def output(self, out: str):
        self.graph.output(self.nodes[out])

    def create_nn_module_stack(
        self, module_fqn: str
    ) -> OrderedDict[int, Tuple[str, type]]:
        cur_name = ""
        nn_module_stack = OrderedDict()
        for atom in module_fqn.split("."):
            if cur_name == "":
                cur_name = atom
            else:
                cur_name = cur_name + "." + atom
            if cur_name not in self.nn_module_stack_key:
                id_counter = self.latest_id
                self.latest_id += 1
                self.nn_module_stack_key[cur_name] = id_counter
            else:
                id_counter = self.nn_module_stack_key[cur_name]
            nn_module_stack[id_counter] = (cur_name, torch.nn.Module)
        return nn_module_stack
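
    # For example, on a fresh builder, create_nn_module_stack("foo.bar") returns
    # OrderedDict({0: ("foo", torch.nn.Module), 1: ("foo.bar", torch.nn.Module)});
    # repeated FQNs reuse the ids recorded in self.nn_module_stack_key.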

    def create_input_specs(self):
        input_specs = []
        for node in self.graph.nodes:
            if node.op == "placeholder":
                input_specs.append(
                    InputSpec(
                        kind=self.input_to_kind[node],
                        arg=TensorArgument(name=node.name),
                        target=None,
                        persistent=True
                        if self.input_to_kind[node] == InputKind.BUFFER
                        else None,
                    )
                )
        return input_specs

    # NOTE: does not handle non-user-outputs atm
    def gen_graph_signature(self) -> ExportGraphSignature:
        output = [n for n in self.graph.nodes if n.op == "output"]
        assert len(output) == 1
        output = output[0]
        assert len(output.args) == 1, "multiple outputs NYI"

        return ExportGraphSignature(
            input_specs=self.create_input_specs(),
            output_specs=[
                OutputSpec(
                    kind=OutputKind.USER_OUTPUT,
                    arg=TensorArgument(name=n.name),
                    target=None,
                )
                for n in output.args
            ],
        )
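

# lift_constants_pass should replace each get_attr constant node with a
# placeholder, record the constant in the returned constants table under its
# FQN, and update the graph signature to match; the tests below check each of
# those effects.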
class TestLift(TestCase):
    def setUp(self):
        if IS_MACOS:
            raise unittest.SkipTest("non-portable load_library call used in test")
        elif IS_SANDCASTLE or IS_FBCODE:
            torch.ops.load_library(
                "//caffe2/test/cpp/jit:test_custom_class_registrations"
            )
        elif IS_WINDOWS:
            lib_file_path = find_library_location("torchbind_test.dll")
            torch.ops.load_library(str(lib_file_path))
        else:
            lib_file_path = find_library_location("libtorchbind_test.so")
            torch.ops.load_library(str(lib_file_path))

    def test_lift_basic(self):
        builder = GraphBuilder()

        builder.input("param", torch.rand(2, 3), InputKind.PARAMETER)
        builder.input("buffer", torch.rand(2, 3), InputKind.BUFFER)
        builder.input("x", torch.rand(2, 3), InputKind.USER_INPUT)
        builder.input("y", torch.rand(2, 3), InputKind.USER_INPUT)

        builder.add("x", "y", out="foo")
        builder.add("foo", "param", out="bar")
        builder.add("bar", "buffer", out="baz")
        builder.constant("const_tensor", torch.rand(2, 3))
        builder.constant("const_obj", torch.classes._TorchScriptTesting._Foo(10, 20))
        builder.add("baz", "const_tensor", out="out")
        builder.call_function(
            torch.ops._TorchScriptTesting.takes_foo,
            ("const_obj", "x"),
            out="torchbind_out",
        )
        builder.add("out", "torchbind_out", out="final_out")
        builder.output("final_out")
        builder.graph.lint()

        graph = builder.graph
        const_tensor = builder.values["const_tensor"]
        const_obj = builder.values["const_obj"]
        root = {"const_tensor": const_tensor, "const_obj": const_obj}
        gm = torch.fx.GraphModule(root, graph)
        graph_signature = builder.gen_graph_signature()
        constants = lift_constants_pass(gm, graph_signature, {})
        gm.graph.lint()

        self.assertEqual(len(constants), 2)

        # The key of the constants table should match the fqn of the constant.
        # In this case, it's just the name of the constant, since the constant
        # is at the root submodule.
        # TODO(suo): we shouldn't hardcode these names in the test, this is an
        # internal detail of the pass.
        self.assertIn("_lifted_tensor_constant0", constants)
        self.assertEqual(constants["_lifted_tensor_constant0"], const_tensor)
        self.assertIn("_lifted_custom_obj0", constants)
        self.assertEqual(constants["_lifted_custom_obj0"], const_obj)

        # The constant node should be removed.
        getattr_nodes = [n for n in gm.graph.nodes if n.op == "get_attr"]
        self.assertEqual(len(getattr_nodes), 0)

        # The constant should be lifted to a placeholder node.
        placeholder_nodes = [n for n in gm.graph.nodes if n.op == "placeholder"]
        self.assertEqual(len(placeholder_nodes), 6)

        # The lifted constant should be placed before user inputs but after
        # params/buffers.
        lifted_tensor_placeholder = placeholder_nodes[2]
        self.assertEqual(lifted_tensor_placeholder.target, "_lifted_tensor_constant0")
        # It should have a val equivalent to the constant.
        self.assertEqual(lifted_tensor_placeholder.meta["val"], const_tensor)

        lifted_obj_placeholder = placeholder_nodes[3]
        self.assertEqual(lifted_obj_placeholder.target, "_lifted_custom_obj0")
        # It should have a val equivalent to the constant.
        self.assertEqual(
            lifted_obj_placeholder.meta["val"],
            CustomObjArgument(
                name="_lifted_custom_obj0",
                class_fqn="__torch__.torch.classes._TorchScriptTesting._Foo",
            ),
        )

        # Graph signature should have been mutated in a way that reflects the
        # new placeholders.
        tensor_constant_input_spec = graph_signature.input_specs[2]
        self.assertEqual(tensor_constant_input_spec.kind, InputKind.CONSTANT_TENSOR)
        self.assertIsInstance(tensor_constant_input_spec.arg, TensorArgument)
        self.assertEqual(
            tensor_constant_input_spec.arg.name, lifted_tensor_placeholder.name
        )

        obj_constant_input_spec = graph_signature.input_specs[3]
        self.assertEqual(obj_constant_input_spec.kind, InputKind.CUSTOM_OBJ)
        self.assertIsInstance(obj_constant_input_spec.arg, CustomObjArgument)
        self.assertEqual(obj_constant_input_spec.arg.name, lifted_obj_placeholder.name)

    def test_lift_nested(self):
        builder = GraphBuilder()
        builder.input("x", torch.rand(2, 3), InputKind.USER_INPUT)
        builder.input("y", torch.rand(2, 3), InputKind.USER_INPUT)
        builder.input("z", torch.rand(2, 3), InputKind.USER_INPUT)

        builder.add("x", "y", out="foo")
        builder.add("foo", "z", out="bar", module_fqn="foo")
        builder.constant("const_tensor", torch.rand(2, 3), module_fqn="foo")
        builder.add("bar", "const_tensor", "out")
        builder.output("out")
        graph = builder.graph
        graph.lint()

        const_tensor = builder.values["const_tensor"]
        root = {"const_tensor": builder.values["const_tensor"]}
        graph_signature = builder.gen_graph_signature()
        gm = torch.fx.GraphModule(root, graph)
        constants = lift_constants_pass(gm, graph_signature, {})
        gm.graph.lint()

        self.assertEqual(len(constants), 1)

        # The key of the constants table should match the fqn of the constant.
        self.assertIn("foo._lifted_tensor_constant0", constants)
        self.assertEqual(constants["foo._lifted_tensor_constant0"], const_tensor)

        # The constant node should be removed.
        getattr_nodes = [n for n in gm.graph.nodes if n.op == "get_attr"]
        self.assertEqual(len(getattr_nodes), 0)

        # The constant should be lifted to a placeholder node.
        placeholder_nodes = [n for n in gm.graph.nodes if n.op == "placeholder"]
        self.assertEqual(len(placeholder_nodes), 4)

        # The lifted constant should be placed before user inputs but after
        # params/buffers.
        lifted_constant_placeholder = placeholder_nodes[0]
        self.assertEqual(lifted_constant_placeholder.target, "_lifted_tensor_constant0")

        # Graph signature should have been mutated in a way that reflects the
        # new placeholders.
        constant_input_spec = graph_signature.input_specs[0]
        self.assertEqual(constant_input_spec.kind, InputKind.CONSTANT_TENSOR)
        self.assertIsInstance(constant_input_spec.arg, TensorArgument)
        self.assertEqual(constant_input_spec.arg.name, lifted_constant_placeholder.name)

    def test_duplicate_constant_access(self):
        const = torch.rand(2, 3)
        const_obj = torch.classes._TorchScriptTesting._Foo(10, 20)

        builder = GraphBuilder()
        builder.input("x", torch.rand(2, 3), InputKind.USER_INPUT)
        builder.constant("const_tensor", const, target="const_tensor")
        # loading the same target twice
        builder.constant("const_tensor2", const, target="const_tensor")
        # loading the same object twice with different targets
        builder.constant("const_obj", const_obj)
        builder.constant("const_obj2", const_obj)
        builder.call_function(
            torch.ops._TorchScriptTesting.takes_foo,
            ("const_obj", "x"),
            out="torchbind_out",
        )
        builder.call_function(
            torch.ops._TorchScriptTesting.takes_foo,
            ("const_obj2", "x"),
            out="torchbind_out2",
        )
        builder.add("x", "const_tensor", out="foo")
        builder.add("foo", "const_tensor2", out="tensor_out")
        builder.add("torchbind_out", "torchbind_out2", out="obj_out")
        builder.add("tensor_out", "obj_out", out="out")
        builder.output("out")
        graph = builder.graph
        graph.lint()

        input_specs = builder.create_input_specs()
        output_specs = [
            OutputSpec(
                kind=OutputKind.USER_OUTPUT,
                arg=TensorArgument(name=builder.nodes["out"].name),
                target=None,
            )
        ]
        graph_signature = ExportGraphSignature(input_specs, output_specs)

        root = {"const_tensor": const, "const_obj": const_obj, "const_obj2": const_obj}
        gm = torch.fx.GraphModule(root, graph)
        constants = lift_constants_pass(gm, graph_signature, {})
        gm.graph.lint()

        self.assertEqual(len(constants), 2)

        # All get_attr nodes should be removed.
        getattr_nodes = [n for n in gm.graph.nodes if n.op == "get_attr"]
        self.assertEqual(len(getattr_nodes), 0)

        # There should only be two additional inputs (plus the existing user input).
        placeholder_nodes = [n for n in gm.graph.nodes if n.op == "placeholder"]
        self.assertEqual(len(placeholder_nodes), 3)

        # Graph signature should have been mutated in a way that reflects the
        # new placeholders.
        self.assertEqual(len(graph_signature.input_specs), 3)
        constant_input_spec = graph_signature.input_specs[0]
        self.assertEqual(constant_input_spec.kind, InputKind.CONSTANT_TENSOR)
        self.assertIsInstance(constant_input_spec.arg, TensorArgument)

    def test_unlift_nonpersistent_buffer(self):
        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer(
                    "non_persistent_buf", torch.zeros(1), persistent=False
                )

            def forward(self, x):
                self.non_persistent_buf.add_(1)
                return x.sum() + self.non_persistent_buf.sum()

        foo = Foo()
        exported = torch.export.export(foo, (torch.ones(5, 5),), strict=False)
        stateful_gm = _unlift_exported_program_lifted_states(exported)

        # Check that the unlifted stateful_gm contains the original non-persistent buffer.
        self.assertTrue(hasattr(stateful_gm, "non_persistent_buf"))
        non_persistent_buf = stateful_gm.get_buffer("non_persistent_buf")
        self.assertEqual(non_persistent_buf, foo.get_buffer("non_persistent_buf"))
        self.assertIn("non_persistent_buf", stateful_gm._non_persistent_buffers_set)
        self.assertNotIn("non_persistent_buf", stateful_gm.state_dict())
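

# ConstantAttrMap maps constant values (tensors and custom objects) to the
# attribute FQNs they live at; test_dict_api checks that it behaves like an
# ordered dict and rejects keys that are not constants (e.g. plain ints).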
class ConstantAttrMapTest(TestCase):
    def setUp(self):
        if IS_MACOS:
            raise unittest.SkipTest("non-portable load_library call used in test")
        elif IS_SANDCASTLE or IS_FBCODE:
            torch.ops.load_library(
                "//caffe2/test/cpp/jit:test_custom_class_registrations"
            )
        elif IS_WINDOWS:
            lib_file_path = find_library_location("torchbind_test.dll")
            torch.ops.load_library(str(lib_file_path))
        else:
            lib_file_path = find_library_location("libtorchbind_test.so")
            torch.ops.load_library(str(lib_file_path))

    def test_dict_api(self):
        constant_attr_map = ConstantAttrMap()
        const_obj = torch.classes._TorchScriptTesting._Foo(10, 20)
        const_tensor = torch.ones(2, 3)
        constant_attr_map[const_obj] = "foo.bar"
        constant_attr_map[const_tensor] = "foo.bar.baz"
        self.assertEqual(len(constant_attr_map), 2)
        self.assertEqual(list(constant_attr_map), [const_obj, const_tensor])
        self.assertEqual(list(constant_attr_map.keys()), [const_obj, const_tensor])
        self.assertEqual(list(constant_attr_map.values()), ["foo.bar", "foo.bar.baz"])
        self.assertEqual(constant_attr_map[const_obj], "foo.bar")
        self.assertEqual(constant_attr_map[const_tensor], "foo.bar.baz")
        self.assertTrue(const_obj in constant_attr_map)
        with self.assertRaises(TypeError):
            constant_attr_map[1] = "foo.bar"
        del constant_attr_map[const_obj]
        self.assertEqual(len(constant_attr_map), 1)


if __name__ == "__main__":
    run_tests()