# blob: 01c2f2b29a2d93da79c5bc33478661b32b9e9a00 [file] [log] [blame]
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
from unittest import TestCase
from executorch import exir
from executorch.exir import to_edge
from executorch.exir.passes import DebugPass, HintBasedSymShapeEvalPass, SpecPropPass
from executorch.exir.tests.models import Repeat, TensorItem
from torch.export import export
class TestDynamicShapeProp(TestCase):
    """Spec checks on the exported ``Repeat`` model after shape passes."""

    def test_repeat(self):
        """Export ``Repeat`` with its dynamic shapes and verify output specs.

        The model is exported, lowered with ``to_edge`` (IR validity check
        disabled), transformed by ``SpecPropPass`` and
        ``HintBasedSymShapeEvalPass``, and the ``"spec"`` metadata of the
        last graph node is then inspected.
        """
        model = Repeat()
        sample = model.get_random_inputs()
        example_args = (sample[0], sample[1])

        exported = export(
            model,
            example_args,
            dynamic_shapes=model.get_dynamic_shape(),
        )
        edge_config = exir.EdgeCompileConfig(_check_ir_validity=False)
        edge_prog = to_edge(exported, compile_config=edge_config)

        passes = [SpecPropPass(), HintBasedSymShapeEvalPass()]
        transformed = edge_prog.transform(passes)
        graph_module = transformed.exported_program().graph_module

        # Presumably dumps the graph with specs to aid debugging.
        debug_pass = DebugPass(show_spec=True)
        debug_pass(graph_module)

        # Only the last node of the graph is of interest here.
        *_, output_node = graph_module.graph.nodes
        output_specs = output_node.meta["spec"]
        self.assertEqual(len(output_specs), 2)

        spec_a, spec_b = output_specs
        for spec in (spec_a, spec_b):
            self.assertTrue(spec.is_upper_bound_tensor)
        self.assertEqual(spec_a.shape, [4, 5])
class TestUnbackedSymInt(TestCase):
    """Spec checks on the exported ``TensorItem`` model after shape passes."""

    def test_unbacked_symint(self):
        """Export ``TensorItem`` without dynamic shapes and verify its spec.

        The model is exported, lowered with ``to_edge`` (IR validity check
        disabled), transformed by ``SpecPropPass`` and
        ``HintBasedSymShapeEvalPass``, and the single entry of the last graph
        node's ``"spec"`` metadata is then inspected.
        """
        model = TensorItem()
        sample = model.get_random_inputs()
        example_args = (sample[0], sample[1])

        exported = export(model, example_args, dynamic_shapes=None)
        edge_config = exir.EdgeCompileConfig(_check_ir_validity=False)
        edge_prog = to_edge(exported, compile_config=edge_config)

        passes = [SpecPropPass(), HintBasedSymShapeEvalPass()]
        transformed = edge_prog.transform(passes)
        graph_module = transformed.exported_program().graph_module

        # Presumably dumps the graph with specs to aid debugging.
        debug_pass = DebugPass(show_spec=True)
        debug_pass(graph_module)

        # Only the last node of the graph is of interest here.
        *_, output_node = graph_module.graph.nodes
        output_specs = output_node.meta["spec"]
        self.assertEqual(len(output_specs), 1)

        only_spec = output_specs[0]
        self.assertTrue(only_spec.is_upper_bound_tensor)
        # upper bound of TensorItem model
        self.assertEqual(output_specs[0].shape, [100, 100])