# blob: f64a1f19981bf37adfff655912d7135009a4051b [file] [log] [blame]
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
import unittest
import torch
import torch.nn as nn
from executorch.exir import memory, to_edge
from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.exir.passes import MemoryPlanningPass
class TestModel1(nn.Module):
    """Small module mixing views of a parameter, an input and intermediates.

    The trailing comments in ``forward`` record, for each view, whether the
    tests in this file expect it to be removed (turned into a ``memory.view``)
    when ``remove_view_copy`` is enabled.
    """

    def __init__(self):
        super().__init__()
        # Both parameters are frozen (requires_grad=False).
        self.parameter = nn.Parameter(torch.rand(5, 6), requires_grad=False)
        self.parameter2 = nn.Parameter(torch.rand(30), requires_grad=False)

    def forward(self, x):
        """Return two differently shaped views of ``(parameter * x) * parameter2``."""
        # removed, lifetime of parameter will be extended
        param_6x5 = self.parameter.view(6, 5)
        # not removed
        input_6x5 = x.view(6, 5)
        product = torch.ops.aten.mul.Tensor(param_6x5, input_6x5)
        # removed, lifetime of mul.Tensor will be extended
        product_flat = product.view(30)
        scaled = torch.ops.aten.mul.Tensor(product_flat, self.parameter2)
        # Neither of the two views below is removed: they are graph outputs.
        return scaled.view(6, 5), scaled.view(2, 15)

    def get_example_inputs(self):
        """Return a one-element tuple holding a random ``(5, 6)`` input tensor."""
        sample = torch.rand(5, 6)
        return (sample,)
class TestRemoveViewCopy(unittest.TestCase):
    """Tests for the ``remove_view_copy`` option of ``to_executorch``."""

    @staticmethod
    def _backend_config(remove_view_copy: bool) -> ExecutorchBackendConfig:
        """Build the backend config shared by every test in this class.

        Args:
            remove_view_copy: Value forwarded to
                ``ExecutorchBackendConfig(remove_view_copy=...)``.

        Returns:
            A config using a fresh greedy ``MemoryPlanningPass`` with
            ``alloc_graph_input=False``.
        """
        return ExecutorchBackendConfig(
            remove_view_copy=remove_view_copy,
            memory_planning_pass=MemoryPlanningPass(
                "greedy", alloc_graph_input=False
            ),
        )

    def test_disable(self) -> None:
        """With ``remove_view_copy=False`` no node may target ``memory.view``."""
        model = TestModel1()
        model.eval()
        example_inputs = model.get_example_inputs()
        ep = torch.export.export(model, example_inputs)

        etpm = to_edge(ep).to_executorch(
            config=self._backend_config(remove_view_copy=False),
        )

        # Use a unittest assertion rather than a bare ``assert`` so the check
        # is not stripped when Python runs with -O.
        for node in etpm.exported_program().graph_module.graph.nodes:
            self.assertNotEqual(node.target, memory.view)

    def test_output_matches(self) -> None:
        """Outputs must match with and without view_copy removal."""
        model = TestModel1()
        model.eval()
        example_inputs = model.get_example_inputs()
        ep = torch.export.export(model, example_inputs)

        epm_remove = to_edge(ep)
        epm_no_remove = copy.deepcopy(
            epm_remove
        )  # to_executorch modifies the edge_program, so we make a copy

        # Run pass with removal
        etpm_remove = epm_remove.to_executorch(
            config=self._backend_config(remove_view_copy=True),
        )

        # Run pass with no removal. This must use remove_view_copy=False;
        # otherwise the test would compare the optimized program with itself.
        etpm_no_remove = epm_no_remove.to_executorch(
            config=self._backend_config(remove_view_copy=False),
        )

        out_remove_v5, out_remove_v6 = etpm_remove.exported_program().module()(
            *example_inputs
        )
        out_no_remove_v5, out_no_remove_v6 = etpm_no_remove.exported_program().module()(
            *example_inputs
        )

        self.assertTrue(torch.allclose(out_remove_v5, out_no_remove_v5))
        self.assertTrue(torch.allclose(out_remove_v6, out_no_remove_v6))

    def test_spec(self) -> None:
        """Check tensor specs and the execution plan with removal enabled."""
        model = TestModel1()
        model.eval()
        example_inputs = model.get_example_inputs()
        ep = torch.export.export(model, example_inputs)

        etpm = to_edge(ep).to_executorch(
            config=self._backend_config(remove_view_copy=True),
        )

        # NOTE(review): the table below lists a single view_copy output, but
        # TestModel1.forward returns two views and this test expects seven
        # instructions, so the table looks stale - confirm and regenerate.
        #
        # etpm.exported_program().graph.print_tabular()
        # idx opcode name target args kwargs
        # --- ------------- ------------------------ ---------------------------------- -------------------------------------------------- ----------------
        # 0 placeholder p_parameter p_parameter () {}
        # 1 placeholder p_parameter2 p_parameter2 () {}
        # 2 placeholder x x () {}
        # 3 call_function aten_view_copy_default <function view at 0x7fe57bea6d40> (p_parameter, [6, 5]) {}
        # 4 call_function aten_view_copy_default_1 <function view at 0x7fe57bea6d40> (x, [6, 5]) {}
        # 5 call_function alloc <function alloc at 0x7fe57bea6c20> (((6, 5), torch.float32),) {}
        # 6 call_function aten_mul_tensor aten.mul.out (aten_view_copy_default, aten_view_copy_default_1) {'out': alloc}
        # 7 call_function aten_view_copy_default_2 <function view at 0x7fe57bea6d40> (aten_mul_tensor, [30]) {}
        # 8 call_function alloc_1 <function alloc at 0x7fe57bea6c20> (((30,), torch.float32),) {}
        # 9 call_function aten_mul_tensor_1 aten.mul.out (aten_view_copy_default_2, p_parameter2) {'out': alloc_1}
        # 10 call_function alloc_2 <function alloc at 0x7fe57bea6c20> (((6, 5), torch.float32),) {}
        # 11 call_function aten_view_copy_default_3 aten.view_copy.out (aten_mul_tensor_1, [6, 5]) {'out': alloc_2}
        # 12 output output_1 output ((aten_view_copy_default_3,),) {}

        for node in etpm.exported_program().graph.nodes:
            if node.name == "p_parameter":
                # p_parameter's lifetime is extended through aten_view_copy_default (memory.view) to idx 6
                self.assertEqual(node.meta["spec"].lifetime, [0, 6])
            elif node.name == "aten_view_copy_default":
                # aten_view_copy_default is a memory.view of p_parameter.
                # p_parameter is a constant with storage, so we check that the view's storage matches the base
                spec = node.meta["spec"]
                base = node.args[0]
                base_spec = base.meta["spec"]

                # assert base is p_parameter
                self.assertEqual(base.name, "p_parameter")

                # assert base is const with storage
                self.assertTrue(base_spec.const)
                self.assertIsNotNone(base_spec.storage)
                self.assertIsNone(base_spec.mem_id)
                self.assertIsNone(base_spec.mem_offset)

                # assert self is const with storage
                self.assertTrue(spec.const)
                self.assertIsNotNone(spec.storage)
                self.assertIsNone(spec.mem_id)
                self.assertIsNone(spec.mem_offset)

                # assert storage matches
                self.assertEqual(spec.storage, base_spec.storage)

                # assert lifetime matches
                self.assertEqual(spec.lifetime, base_spec.lifetime)
            elif node.name == "aten_mul_tensor":
                # aten_mul_tensor's lifetime is extended through aten_view_copy_default_2 (memory.view) to idx 9
                self.assertEqual(node.meta["spec"].lifetime, [5, 9])
            elif node.name == "aten_view_copy_default_2":
                # aten_view_copy_default_2 is a memory.view of aten_mul_tensor
                spec = node.meta["spec"]
                base = node.args[0]
                base_spec = base.meta["spec"]

                # assert base is aten_mul_tensor
                self.assertEqual(base.name, "aten_mul_tensor")

                # assert base and self are not const, do not have storage,
                # but do have mem_id and mem_offset
                self.assertFalse(base_spec.const)
                self.assertIsNone(base_spec.storage)
                self.assertIsNotNone(base_spec.mem_id)
                self.assertIsNotNone(base_spec.mem_offset)

                self.assertFalse(spec.const)
                self.assertIsNone(spec.storage)
                self.assertIsNotNone(spec.mem_id)
                self.assertIsNotNone(spec.mem_offset)

                # assert self and base mem_id, mem_offset, and lifetime matches
                self.assertEqual(spec.mem_id, base_spec.mem_id)
                self.assertEqual(spec.mem_offset, base_spec.mem_offset)
                self.assertEqual(spec.lifetime, base_spec.lifetime)

        # Test evalues in execution plan
        plan = etpm.executorch_program.execution_plan[0]
        self.assertEqual(plan.operators[0].name, "executorch_prim::et_view")
        self.assertEqual(plan.operators[1].name, "aten::mul")
        self.assertEqual(plan.operators[2].name, "aten::view_copy")

        instructions = plan.chains[0].instructions
        self.assertEqual(len(instructions), 7)

        # The idx values below refer to the graph table above.
        self.assertEqual(
            instructions[0].instr_args.op_index, 0  # pyre-ignore
        )  # view @ idx3
        self.assertEqual(
            instructions[1].instr_args.op_index, 0  # pyre-ignore
        )  # view @ idx4
        self.assertEqual(
            instructions[2].instr_args.op_index, 1  # pyre-ignore
        )  # aten:mul @ idx6
        self.assertEqual(
            instructions[3].instr_args.op_index, 0  # pyre-ignore
        )  # view @ idx7
        self.assertEqual(
            instructions[4].instr_args.op_index, 1  # pyre-ignore
        )  # aten:mul @ idx9
        self.assertEqual(
            instructions[5].instr_args.op_index, 2  # pyre-ignore
        )  # aten:view_copy @ idx11
        self.assertEqual(
            instructions[6].instr_args.op_index, 2  # pyre-ignore
        )  # aten:view_copy; presumably the second graph output, not in the table above