| # Copyright (c) Meta Platforms, Inc. and affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| # flake8: noqa: F401 |
| import functools |
| import inspect |
| import os |
| import random |
| import unittest |
| from typing import Callable, Dict, Optional, Tuple, Type |
| from unittest import skip, skipUnless |
| |
| import executorch.exir as exir |
| |
| import executorch.exir.control_flow as control_flow |
| |
| # @manual=//executorch/extension/pytree:pybindings |
| import executorch.extension.pytree as pytree |
| import torch |
| |
| from executorch.exir import ( |
| CaptureConfig, |
| EdgeCompileConfig, |
| ExecutorchBackendConfig, |
| memory, |
| ) |
| from executorch.exir.dynamic_shape import DynamicMemoryPlanningMode |
| from executorch.exir.emit import emit_program |
| from executorch.exir.pass_manager import PassManager |
| from executorch.exir.passes import ( |
| DebugPass, |
| MemoryPlanningPass, |
| to_scratch_op_pass, |
| ToOutVarPass, |
| ) |
| from executorch.exir.print_program import pretty_print, print_program |
| from executorch.exir.tensor import make_tensor_value, TensorSpec |
| from executorch.exir.tests.control_flow_models import ( |
| FTCondBasic, |
| FTCondDynShape, |
| FTMapBasic, |
| FTMapDynShape, |
| ) |
| from executorch.exir.tests.dynamic_shape_models import BatchNormModel |
| |
| from executorch.exir.tests.transformer import Transformer |
| from functorch.experimental.control_flow import cond |
| |
# Detect which ExecuTorch pybindings flavor is available. Both the portable
# ("lean") and the ATen ("aten") libraries export the same loader functions
# under the same names; exactly one of them must be importable.
kernel_mode = None  # either aten mode or lean mode
try:
    from executorch.extension.pybindings.portable_lib import (
        _load_bundled_program_from_buffer,
        _load_for_executorch_from_buffer,
        _load_for_executorch_from_bundled_program,
    )

    kernel_mode = "lean"
except ImportError as e:
    # Not fatal yet: the ATen flavor may still be importable below.
    print(e)
    pass

try:
    from executorch.extension.pybindings.aten_lib import (
        _load_bundled_program_from_buffer,
        _load_for_executorch_from_buffer,
        _load_for_executorch_from_bundled_program,
    )

    # The two flavors must not both be importable; this AssertionError is
    # not caught by the ImportError handler and aborts module import.
    assert kernel_mode is None
    kernel_mode = "aten"
except ImportError as e:
    print(e)
    pass

# At least one flavor must have been imported successfully.
assert kernel_mode is not None

# Convenience flags for tests whose behavior depends on the kernel flavor
# (e.g. ops that are only implemented in ATen mode).
is_aten_mode = kernel_mode == "aten"
is_lean_mode = kernel_mode == "lean"
| |
| from torch import nn |
| from torch.utils import _pytree as torch_pytree |
| |
| from .exported_module import ExportedModule |
| |
| |
# Set RUN_SKIPPED=1 in the environment to also run the tests that are
# normally disabled via @skipUnless(RUN_SKIPPED, ...).
RUN_SKIPPED = int(os.getenv("RUN_SKIPPED", "0"))
| |
| |
class ModuleBasic(nn.Module):
    """Reduces ``sin(x)`` to its single maximum element."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.sin(x).max()

    def get_random_inputs(self):
        """Return one random 1-D tensor of 100 elements."""
        return (torch.randn(100),)
| |
| |
class ModuleOpsReturnMulti(nn.Module):
    """Exercises an op with multiple outputs (``topk`` returns values and indices)."""

    def __init__(self):
        super().__init__()

    def forward(self, a, b):
        # Only the top-3 values are used; the indices output is discarded.
        x, _ = torch.topk(a, 3)
        return x * 2 + b

    def get_random_inputs(self):
        """Return a 10-element input for ``topk`` and a 3-element addend."""
        return (torch.randn(10), torch.randn(3))
| |
| |
class ModuleAdd(nn.Module):
    """Elementwise ``torch.add`` of two tensors."""

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return torch.add(x, y)

    def get_random_inputs(self):
        """Return two random 2x2 tensors."""
        return (torch.randn(2, 2), torch.randn(2, 2))
| |
| |
class ModuleFloatAddWithAlpha(nn.Module):
    """Computes ``x + c * y`` with a Python float ``alpha``."""

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor, c: float):
        return torch.add(x, y, alpha=c)

    def get_random_inputs(self):
        """Return two random 2x2 tensors and a float alpha in [0, 1)."""
        return (torch.randn(2, 2), torch.randn(2, 2), random.random())
| |
| |
class ModuleIntAddWithAlpha(nn.Module):
    """Computes ``x + c * y`` on integer tensors with a Python int ``alpha``."""

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor, c: int):
        return torch.add(x, y, alpha=c)

    def get_random_inputs(self):
        """Return two random int 2x2 tensors in [0, 10) and an int alpha in [0, 10]."""
        return (
            torch.randint(0, 10, (2, 2)),
            torch.randint(0, 10, (2, 2)),
            random.randint(0, 10),
        )
| |
| |
class ModuleContainers(nn.Module):
    """Takes a dict input and returns a nested dict/tuple output.

    The input dict must contain tensors under ``"a"`` and ``"b"``; the output
    echoes both inputs and their sum.
    """

    def __init__(self):
        super().__init__()

    def forward(self, d):
        a = d["a"]
        b = d["b"]
        return {"inputs": (a, b), "c": torch.add(a, b)}

    def get_random_inputs(self):
        """Return a single dict argument holding two random 2x2 tensors."""
        return ({"a": torch.randn(2, 2), "b": torch.randn(2, 2)},)
| |
| |
class ToyModelForMemPlanning(nn.Module):
    """Chain of elementwise mul/add ops producing several short-lived intermediates."""

    def __init__(self):
        super().__init__()

    def forward(self, a, b):
        o = a
        # Three rounds of ``o = o * a + b``, kept as separate ops.
        for _ in range(3):
            o = o * a
            o = o + b
        return o

    def get_random_inputs(self):
        """Return two random 10-element tensors."""
        return (
            torch.randn(10),
            torch.randn(10),
        )
| |
| |
class MemPlanningWithScratchTensor(nn.Module):
    """Sum of two independent ``nn.Linear(4, 2)`` projections."""

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(4, 2)
        self.linear2 = nn.Linear(4, 2)

    def forward(self, a, b):
        o1 = self.linear1(a)
        o2 = self.linear2(b)
        return o1 + o2

    def get_random_inputs(self):
        """Return two random (10, 4) tensors."""
        return (
            torch.randn(10, 4),
            torch.randn(10, 4),
        )
| |
| |
class ModuleOpsReturnTensorList(nn.Module):
    """Exercises an op returning a tensor list (``aten.tensor_split.sections``)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Split into 3 sections and keep only the first one.
        split = torch.ops.aten.tensor_split.sections(x, 3)
        return split[0]

    def get_random_inputs(self):
        """Return one random 1-D tensor of 100 elements."""
        return (torch.randn(100),)
| |
| |
class ModuleReturnInput(nn.Module):
    """Returns its input, unchanged, at every leaf of a nested tuple/dict/list."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return (x, x, {"x": x, "y": x}, [x, x, x])

    def get_random_inputs(self):
        """Return one random single-element tensor."""
        return (torch.randn(1),)
| |
| |
class ModuleIfElse(nn.Module):
    """Data-dependent branching via ``cond`` with a tensor predicate.

    Returns ``(3 * x**2) ** 2`` when ``c`` is true and ``(4 * x**2) ** 2``
    otherwise.
    """

    def __init__(self):
        super().__init__()

    def forward(self, c, x):
        x = x * x

        # Returns n * x, computed as n - 1 explicit additions of x.
        def addloop(x, n):
            out = x
            for _ in range(n - 1):
                out = out + x
            return out

        # Both branches receive (c, x) as operands but only use x.
        def true_branch(c, x):
            return addloop(x, 3)

        def false_branch(c, x):
            return addloop(x, 4)

        y = cond(c, true_branch, false_branch, (c, x))
        return y * y

    def get_random_inputs(self):
        # Predicate: a random boolean tensor of shape [1]; data: 10 random values.
        return (torch.randint(2, [1]) == 0, torch.randn(10))
| |
| |
class ModuleIfElseWithBoolInput(nn.Module):
    """Same as ``ModuleIfElse`` but the ``cond`` predicate is a Python ``bool``.

    Returns ``(3 * x**2) ** 2`` when ``c`` is True and ``(4 * x**2) ** 2``
    otherwise.
    """

    def __init__(self):
        super().__init__()

    def forward(self, c: bool, x: torch.Tensor):
        x = x * x

        # Returns n * x, computed as n - 1 explicit additions of x.
        def addloop(x, n):
            out = x
            for _ in range(n - 1):
                out = out + x
            return out

        # Both branches receive (c, x) as operands but only use x.
        def true_branch(c, x):
            return addloop(x, 3)

        def false_branch(c, x):
            return addloop(x, 4)

        y = cond(c, true_branch, false_branch, (c, x))

        return y * y

    def get_random_inputs(self):
        # Predicate: a random Python bool; data: 10 random values.
        return (random.randint(0, 1) == 0, torch.randn(10))
| |
| |
class ModuleWhileIf(nn.Module):
    """Nested control flow: a ``cond`` inside the body of a ``while_loop``.

    Presumably (assuming ``control_flow.while_loop`` applies ``loop_body``
    while ``loop_cond`` holds), this counts ``cnt`` down to zero and on each
    iteration adds the current ``cnt`` to ``accum``. The ``cond`` predicate is
    always true, so it always selects ``true_branch`` (which returns ``cnt``).
    The final ``accum`` is returned.

    NOTE(review): each nested function is decorated with
    ``control_flow.tracing_context(inputs=...)``, which appears to supply
    example inputs for tracing that function — confirm.
    """

    def __init__(self):
        super().__init__()

    def forward(self, accum, cnt):
        @control_flow.tracing_context(
            inputs=(torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))
        )
        def loop_cond(accum, cnt):
            # Continue while the counter is non-zero.
            return cnt != torch.zeros([1]).to(dtype=torch.long)

        @control_flow.tracing_context(
            inputs=(torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))
        )
        def loop_body(accum, cnt):
            # Equivalent body without the nested cond (kept for reference):
            # return accum + cnt, cnt - torch.ones([1]).to(dtype=torch.long)
            @control_flow.tracing_context(
                inputs=(torch.zeros([1]).to(dtype=torch.long),)
            )
            def true_branch(cnt):
                return cnt

            @control_flow.tracing_context(
                inputs=(torch.zeros([1]).to(dtype=torch.long),)
            )
            def false_branch(cnt):
                return torch.zeros([1], dtype=torch.long)

            # The predicate is constant True, so this always adds cnt.
            accum = accum + cond(
                torch.BoolTensor([True]), true_branch, false_branch, (cnt,)
            )
            # 'cnt - 1' does not work yet since the runtime does not expect
            # tensor to be mixed with scalar for sub op.
            return accum, cnt - torch.ones([1]).to(dtype=torch.long)

        # Only the final accumulator is returned; the final counter is dropped.
        y, _ = control_flow.while_loop(
            loop_cond,
            loop_body,
            (accum, cnt),
        )
        return y

    def get_random_inputs(self):
        # accum starts at 0 (long); cnt is a random long in [10, 100).
        return (torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))
| |
| |
class ModuleIfWhile(nn.Module):
    """Nested control flow: a ``while_loop`` inside the true branch of a ``cond``.

    The ``cond`` predicate is constant True, so ``true_branch`` is always
    selected. Presumably (assuming ``control_flow.while_loop`` applies
    ``loop_body`` while ``loop_cond`` holds), it counts ``cnt`` down to zero,
    adding each value to ``accum``. Only the resulting ``accum`` is returned.

    NOTE(review): each nested function is decorated with
    ``control_flow.tracing_context(inputs=...)``, which appears to supply
    example inputs for tracing that function — confirm.
    """

    def __init__(self):
        super().__init__()

    def forward(self, accum, cnt):
        @control_flow.tracing_context(
            inputs=(torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))
        )
        def true_branch(accum, cnt):
            @control_flow.tracing_context(
                inputs=(
                    torch.zeros([1]).to(dtype=torch.long),
                    torch.randint(10, 100, [1]),
                )
            )
            def loop_cond(accum, cnt):
                # Continue while the counter is non-zero.
                return cnt != torch.zeros([1]).to(dtype=torch.long)

            @control_flow.tracing_context(
                inputs=(
                    torch.zeros([1]).to(dtype=torch.long),
                    torch.randint(10, 100, [1]),
                )
            )
            def loop_body(accum, cnt):
                return accum + cnt, cnt - torch.ones([1]).to(dtype=torch.long)

            return control_flow.while_loop(loop_cond, loop_body, (accum, cnt))

        @control_flow.tracing_context(
            inputs=(torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))
        )
        def false_branch(accum, cnt):
            # Identity: pass both operands through unchanged.
            return accum, cnt

        # Both branches return (accum, cnt); index 0 selects the accumulator.
        return cond(torch.BoolTensor([True]), true_branch, false_branch, (accum, cnt))[
            0
        ]

    def get_random_inputs(self):
        # accum starts at 0 (long); cnt is a random long in [10, 100).
        return (torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))
| |
| |
class ModuleContiguousTensor(nn.Module):
    """A single ``nn.Linear(8, 32)`` projection, applied to a (3, 8) batch."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 32)

    def forward(self, arg):
        projected = self.linear(arg)
        return projected

    def get_random_inputs(self):
        """Return a single random (3, 8) batch."""
        batch = torch.randn(3, 8)
        return (batch,)
| |
| |
class ModuleInputDynamicShape(nn.Module):
    """Elementwise model whose 1-D input length varies between calls.

    Four times over, the input is doubled and then squared.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        for _ in range(4):
            doubled = x + x
            x = doubled * doubled
        return x

    def get_upper_bound_inputs(self):
        """Return the largest input this model is exercised with (10 elements).

        Presumably used as the upper bound for dynamic-shape planning — confirm
        against the caller.
        """
        return (torch.randn(10),)

    def get_random_inputs(self):
        """Return one random 1-D tensor with a length drawn from [1, 10]."""
        length = random.randint(1, 10)
        return (torch.randn(length),)
| |
| |
class ModuleIntermediateDynamicShape(nn.Module):
    """Model with a data-dependent intermediate shape produced by ``torch.nonzero``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        squared = x * x

        # Ideally this would be squared[torch.nonzero(squared)], but the index
        # op is not supported in the runtime so far.
        indices = torch.nonzero(squared)
        return indices + indices

    def get_random_inputs(self):
        """Return one 10-element float tensor holding random 0s and 1s."""
        zeros_and_ones = torch.randint(0, 2, (10,), dtype=torch.float)
        return (zeros_and_ones,)
| |
| |
def allclose(lhs, rhs, rtol=1e-5, atol=1e-8):
    r"""
    Structural version of ``torch.allclose``.

    Unlike ``torch.allclose``, which only handles Tensor arguments, this also
    handles list, tuple, dict, and arbitrary nesting of these. Lists and
    tuples are interchangeable (a list may be compared against a tuple).

    Args:
        lhs: A Tensor, or a (nested) list/tuple/dict of Tensors.
        rhs: The value to compare against, with the same structure as ``lhs``.
        rtol: Relative tolerance forwarded to ``torch.allclose``.
        atol: Absolute tolerance forwarded to ``torch.allclose``.

    Returns:
        True if both structures match (same sequence lengths, same dict keys)
        and every pair of corresponding Tensors is close.

    Raises:
        RuntimeError: If a pair of corresponding elements has unsupported or
            mismatched types (e.g. a Tensor vs a list, or two Python scalars).
    """
    if isinstance(lhs, torch.Tensor) and isinstance(rhs, torch.Tensor):
        return torch.allclose(lhs, rhs, rtol, atol)
    if isinstance(lhs, (tuple, list)) and isinstance(rhs, (tuple, list)):
        return len(lhs) == len(rhs) and all(
            allclose(a, b, rtol, atol) for a, b in zip(lhs, rhs)
        )
    if isinstance(lhs, dict) and isinstance(rhs, dict):
        # Key views compare like sets: the key sets must match exactly.
        return lhs.keys() == rhs.keys() and all(
            allclose(lhs[k], rhs[k], rtol, atol) for k in lhs
        )
    raise RuntimeError(
        f"Unexpected types: lhs type {type(lhs)}, rhs type {type(rhs)}"
    )
| |
| |
def validate_contiguous_tensors(program):
    """Assert that every tensor value in ``program`` is contiguous.

    Walks ``program.execution_plan[*].values[*].val`` and checks each
    ``exir.schema.Tensor`` found there.

    Args:
        program: The emitted program (presumably an ``exir.schema.Program``)
            exposing ``execution_plan`` entries that carry ``values``.

    Raises:
        AssertionError: If a non-contiguous tensor is found, or if a tensor's
            ``sizes`` and ``dim_order`` have different lengths.
    """

    def _is_contiguous_tensor(tensor: exir.schema.Tensor):
        """
        Return True if the tensor is PyTorch-contiguous
        (torch.memory_format=torch.contiguous_format), i.e. its dim order is
        the identity permutation [0, 1, ..., ndim - 1]. Needed because the
        runtime cannot handle non-contiguous tensors so far.
        """
        sizes = tensor.sizes
        dim_order = tensor.dim_order
        assert len(sizes) == len(dim_order)
        return all(i == val for i, val in enumerate(dim_order))

    for execution_plan in program.execution_plan:
        for value in execution_plan.values:
            if isinstance(value.val, exir.schema.Tensor):
                # Report dim_order, which is what the check actually reads.
                # NOTE(review): the schema Tensor appears to describe layout via
                # dim_order. If it has no `strides` field, reading one here would
                # raise AttributeError instead of the intended AssertionError.
                assert _is_contiguous_tensor(
                    value.val
                ), f"Non-contiguous tensor found: size {value.val.sizes} dim_order {value.val.dim_order}. constant_buffer_idx {value.val.constant_buffer_idx}. allocation_info {value.val.allocation_info}."
| |
| |
class BoundMethod:
    """Binds ``callable`` to ``instance``, mimicking a bound method.

    Calling the object invokes ``callable(instance, *args, **kwargs)``.
    """

    def __init__(self, instance, callable):
        self._instance = instance
        self._callable = callable

    def __call__(self, *args, **kwargs):
        # Must read the attribute stored in __init__ (``_instance``); there is
        # no public ``instance`` attribute on this object.
        return self._callable(self._instance, *args, **kwargs)
| |
| |
def maketest(
    module_cls: Type[nn.Module],
    niter: int = 10,
    run_executor: bool = True,
    do_tree_flatten: bool = False,
    run_graph_module: bool = True,
    atol: float = 1e-8,
    rtol: float = 1e-5,
    ignore_to_out_var_failure: bool = False,
    allow_non_contiguous_tensor: bool = False,
    method: str = "forward",
    dynamic_memory_planning_mode: DynamicMemoryPlanningMode = DynamicMemoryPlanningMode.UPPER_BOUND,
    capture_config: Optional[CaptureConfig] = None,
    verify_graph: Optional[Callable] = None,
) -> Callable[[unittest.TestCase], None]:
    r"""Returns a TestCase method to test the provided module class and method.

    The returned method exports ``module_cls`` via ``ExportedModule.export``.
    It then optionally compares the exported graph module against the eager
    module and validates the emitted ExecuTorch program. Finally, it
    optionally runs the serialized program on the ExecuTorch runtime and
    compares its outputs against the eager module.

    Args:
        module_cls: The subclass of nn.Module to export.
        niter: The number of random input data sets to test with.
        run_executor: Whether to run the model on the executor. We may want to
            skip running a model through the executor since some kernels are
            not implemented.
        do_tree_flatten: Whether to flatten input and unflatten output.
        run_graph_module: Whether to run the traced and transformed GraphModule.
            One may want to skip this if some custom ops do not have an
            implementation in torch.ops but are implemented in the executor.
        atol: Absolute tolerance used in allclose and torch.allclose.
        rtol: Relative tolerance used in allclose and torch.allclose.
        ignore_to_out_var_failure: Whether to ignore the failure when a
            functional op does not have an out variant.
        allow_non_contiguous_tensor: If False, validate that the emitted
            program only contains contiguous tensors.
        method: The name of the module_cls method to trace.
        dynamic_memory_planning_mode: The dynamic memory planning mode to use.
        capture_config: Capture config forwarded as-is to
            ``ExportedModule.export``.
        verify_graph: Optional callback invoked as
            ``verify_graph(test_case, graph_module)`` with the exported graph
            module, before any numerical comparison is made.

    Returns:
        A TestCase method that tests the provided module class and method.
    """

    def wrapper(self: unittest.TestCase) -> None:
        """A TestCase method that traces/exports/tests an nn.Module and method."""
        module = ExportedModule.export(
            module_class=module_cls,
            # testend2end only supports modules with single methods defined
            methods=(method,),
            ignore_to_out_var_failure=ignore_to_out_var_failure,
            dynamic_memory_planning_mode=dynamic_memory_planning_mode,
            capture_config=capture_config,
        )
        if verify_graph:
            verify_graph(self, module.exported_program.graph_module)
        print(f"inputs for tracing: {module.trace_inputs}")

        # Draw all random inputs up front so the graph module and the executor
        # are checked against the same data sets.
        inputs_list = [module.get_random_inputs() for _ in range(niter)]

        # Compare the results between the eager module and the graph module.
        if run_graph_module:
            for inputs in inputs_list:
                with torch.no_grad():
                    # only one method is supported so just grab that single method
                    expected = getattr(module.eager_module, module.methods[0])(*inputs)
                with torch.no_grad():
                    result = module.exported_program.module()(*inputs)
                self.assertTrue(allclose(expected, result, rtol, atol))

        program = module.executorch_program.executorch_program
        pretty_print(program)
        print_program(program, show_meminfo=True, mark_dynamic_shape_tensor=True)
        print(f"mem buffer sizes: {program.execution_plan[0].non_const_buffer_sizes}")
        if not allow_non_contiguous_tensor:
            validate_contiguous_tensors(program)
        self.assertTrue(len(program.execution_plan[0].non_const_buffer_sizes) >= 2)
        # We should not enable the following assertion since, for models that
        # simply return a graph input, no mutable memory should be allocated:
        # self.assertTrue(all(s > 0 for s in program.execution_plan[0].non_const_buffer_sizes[1:]))

        # NOTE(review): this mutates the in-memory program after printing; it
        # is not visible here whether it affects `buffer` below — confirm.
        program.version = 0
        buff = module.executorch_program.buffer
        # Check that the magic version number is in the expected place, and
        # follows the expected pattern.
        self.assertRegex(buff[4:8].decode(errors="replace"), r"^ET[0-9][0-9]$")

        if run_executor:
            print("Running on the runtime")
            executorch_module = _load_for_executorch_from_buffer(buff)
            # Compare the results between the eager module and the executor.
            for idx, inputs in enumerate(inputs_list):
                with torch.no_grad():
                    expected = getattr(module.eager_module, method)(*inputs)

                if do_tree_flatten:
                    # NOTE(review): `*inputs` passes each input as a separate
                    # argument to tree_flatten, so this path presumably only
                    # supports modules taking a single (container) input.
                    # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
                    flatten_inputs, _ = pytree.tree_flatten(*inputs)
                    executorch_result = executorch_module.forward([*flatten_inputs])
                    # pyre-fixme[16]: Module `pytree` has no attribute `TreeSpec`.
                    executorch_result_unflatten = pytree.TreeSpec.from_str(
                        program.execution_plan[0].container_meta_type.encoded_out_str
                    ).tree_unflatten(executorch_result)
                    actual = executorch_result_unflatten
                else:
                    # Only the first runtime output is compared.
                    actual = executorch_module.forward(inputs)[0]
                is_close = allclose(expected, actual, rtol, atol)
                if not is_close:
                    print(f"Fail for {idx}th inputs: {inputs}")
                    print(f"expected result: {expected}")
                    print(f"actual result: {actual}")
                self.assertTrue(is_close)

    return wrapper
| |
| |
class E2ETest(unittest.TestCase):
    r"""
    When adding a new unittest, call maketest(ModuleName) if possible since
    maketest handles all the boilerplate. Ideally, we only need to define
    a new nn.Module and add one line calling maketest for new end2end test cases.
    """

    # Don't run the model through the executor because aten::sin.out is not
    # defined in the executor currently.
    #
    # aten::max.default does not have an out variant, so we need to set
    # ignore_to_out_var_failure to True.
    def test_basic(self):
        maketest(ModuleBasic, run_executor=False, ignore_to_out_var_failure=True)(self)

    # Make sure we can handle ops that return multiple values, e.g. topk.
    # At one time we could not properly set up the TensorSpec for an FX node
    # returning multiple tensors.
    #
    # Don't run the model through the executor because aten::topk.values is
    # not defined in the executor currently.
    def test_ops_return_multi(self):
        maketest(ModuleOpsReturnMulti, run_executor=False)(self)

    @skipUnless(RUN_SKIPPED, "TODO(larryliu0820) Fix this in both fbcode and oss")
    def test_mem_planning_toy_model(self):
        maketest(
            ToyModelForMemPlanning,
            capture_config=exir.CaptureConfig(
                enable_dynamic_shape=True,
            ),
        )(self)

    # TODO: add ops implementations and turn on 'run_executor'
    def test_mem_planning_scratch_tensor(self):
        maketest(
            MemPlanningWithScratchTensor,
            run_graph_module=False,
            run_executor=False,
            atol=1e-5,
        )(self)

    def test_executorch_forward(self):
        maketest(ModuleAdd)(self)

    @skipUnless(RUN_SKIPPED, "TODO(larryliu0820) Fix this in both fbcode and oss")
    def test_containers(self):
        maketest(
            ModuleContainers,
            do_tree_flatten=True,
            capture_config=exir.CaptureConfig(
                enable_dynamic_shape=True,
            ),
        )(self)

    # Cannot run the graph module since the out variant with a tensor list out
    # argument returns None rather than a tensor list.
    #
    # Cannot run in the executor since the kernel for tensor splitting is not
    # implemented.
    def test_ops_return_tensorlist(self):
        maketest(ModuleOpsReturnTensorList, run_graph_module=False, run_executor=False)(
            self
        )

    # Failed to produce a graph during tracing w/ dynamo because there are no torch ops
    # test_return_input = maketest(ModuleReturnInput, do_tree_flatten=True)

    # Cannot run this on the executor because the following ops are missing:
    # aten::select_copy.int_out, aten::eq.Scalar_out
    # TODO(zhxchen17) re-enable these tests.
    # test_control_flow_cond = maketest(ControlFlowCond, run_executor=False)
    # Fails to trace with functionalization enabled
    # test_ifelse = maketest(ModuleIfElse)

    # Fails to trace with functionalization enabled
    # Fails with error: Missing out variants: {'aten::select', 'aten::_shape_as_tensor', 'aten::tensor_split'}
    # TODO(zhxchen17) re-enable these tests.
    # test_while_0 = maketest(
    #     ControlFlowWhile,
    #     ignore_to_out_var_failure=True,
    #     run_executor=False,
    # )

    # test_while = maketest(ModuleWhile)

    # test_while_if = maketest(ModuleWhileIf)
    # test_if_while = maketest(ModuleIfWhile)
    @skipUnless(RUN_SKIPPED, "TODO(larryliu0820) This fails on OSS macos job")
    def test_contiguous_tensor(self):
        maketest(ModuleContiguousTensor, run_executor=False)(self)
| |
| |
class DynamicModelE2ETest(unittest.TestCase):
    """
    End2end tests for dynamic models. By dynamic models we mean models with
    control flow or dynamic shape.
    """

    @staticmethod
    def _ft_capture_config():
        """Build the capture config shared by the functorch cond/map tests.

        Dynamic shape is enabled and functionalization is disabled.
        """
        return exir.CaptureConfig(
            enable_dynamic_shape=True,
            enable_functionalization=False,  # TODO enable functionalization
        )

    @skip("Revisit when unbacked symint is ready")
    def test_intermediate_dynamic_shape(self):
        maketest(
            ModuleIntermediateDynamicShape,
            run_graph_module=False,
            allow_non_contiguous_tensor=True,
            capture_config=exir.CaptureConfig(
                enable_dynamic_shape=True,
            ),
        )(self)

    # TODO(shunting): some non-constant tensors for transformer are
    # non-contiguous. Ignore for now. Will debug more.
    # NOTE: cannot run on the runtime since these ops are missing: P535190636
    @skipUnless(RUN_SKIPPED, "TODO(larryliu0820) This fails on OSS macos job")
    def test_transformer_encode(self):
        maketest(
            Transformer,
            method="encode",
            allow_non_contiguous_tensor=True,
            run_executor=False,
        )(self)

    # Basic test for functorch torch.ops.higher_order.cond.
    @skipUnless(RUN_SKIPPED, "TODO(larryliu0820) Fix this in both fbcode and oss")
    def test_ft_cond_basic(self):
        maketest(FTCondBasic, capture_config=self._ft_capture_config())(self)

    @skipUnless(RUN_SKIPPED, "Emitter is not ready yet")
    def test_ft_map_basic(self):
        maketest(FTMapBasic, capture_config=self._ft_capture_config())(self)

    @skipUnless(RUN_SKIPPED, "TODO(larryliu0820) Fix this in both fbcode and oss")
    def test_ft_cond_dynshape(self):
        maketest(FTCondDynShape, capture_config=self._ft_capture_config())(self)

    @skipUnless(RUN_SKIPPED, "Emitter is not ready yet")
    def test_ft_map_dynshape(self):
        maketest(FTMapDynShape, capture_config=self._ft_capture_config())(self)

    @skipUnless(RUN_SKIPPED, "TODO(larryliu0820) Fix this in both fbcode and oss")
    def test_batch_norm(self):
        maketest(
            BatchNormModel,
            capture_config=exir.CaptureConfig(
                enable_dynamic_shape=True,
            ),
            verify_graph=BatchNormModel.verify_graph,
            # TODO: lean mode does not have native_batch_norm.out implemented;
            # run this in aten mode only.
            run_executor=is_aten_mode,
        )(self)