| # Copyright (c) Meta Platforms, Inc. and affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
# Example script for exporting models to ExecuTorch flatbuffer programs (.pte / .bpte) with the Vulkan delegate.
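#
# Example invocations (the script filename below is a placeholder; substitute the
# path of this file in your checkout):
#
#   python export_vulkan_model.py -m mv2 -o /tmp/vulkan_models
#   python export_vulkan_model.py -m resnet18 --bundled --test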
| |
| # pyre-unsafe |
| |
| import argparse |
| import logging |
| import os |
| |
| import executorch.backends.vulkan.test.utils as test_utils |
| import torch |
| import torchvision |
| from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner |
| from executorch.devtools import BundledProgram |
| from executorch.devtools.bundled_program.config import MethodTestCase, MethodTestSuite |
| from executorch.devtools.bundled_program.serialize import ( |
| serialize_from_bundled_program_to_flatbuffer, |
| ) |
| from executorch.examples.models import MODEL_NAME_TO_MODEL |
| from executorch.examples.models.model_factory import EagerModelFactory |
| from executorch.exir import to_edge_transform_and_lower |
| from executorch.extension.export_util.utils import save_pte_program |
| from executorch.extension.pytree import tree_flatten |
| from torch.export import Dim, export |
| |
| FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" |
| logging.basicConfig(level=logging.INFO, format=FORMAT) |
| |
import urllib.request
| |
| |
| def is_vision_model(model_name): |
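    """Return True if model_name refers to a vision model that uses the shared
    (1, 3, H, W) image sample input and dynamic shapes defined below."""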
| if model_name in [ |
| # These models are also registered in examples/models |
| "dl3", |
| "edsr", |
| "mv2", |
| "mv3", |
| "vit", |
| "ic3", |
| "ic4", |
| "resnet18", |
| "resnet50", |
| # These models are not registered in examples/models but are available via |
| # torchvision |
| "convnext_small", |
| "densenet161", |
| "shufflenet_v2_x1_0", |
| ]: |
| return True |
| |
| return False |
| |
| |
| def get_vision_model_sample_input(): |
| return (torch.randn(1, 3, 224, 224),) |
| |
| |
| def get_vision_model_dynamic_shapes(): |
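    # Height and width are derived dims: the base Dim ranges over [1, 16] and
    # is scaled by 16, so exported spatial sizes must be multiples of 16
    # between 16 and 256.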
| return ( |
| { |
| 2: Dim("height", min=1, max=16) * 16, |
| 3: Dim("width", min=1, max=16) * 16, |
| }, |
| ) |
| |
| |
| def get_dog_image_tensor(image_size=224, normalization="imagenet"): |
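    """Download the PyTorch hub sample dog image and preprocess it into a
    single-element tuple holding a (1, 3, image_size, image_size) tensor,
    optionally normalized with ImageNet mean/std."""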
| url, filename = ( |
| "https://github.com/pytorch/hub/raw/master/images/dog.jpg", |
| "dog.jpg", |
| ) |
    # Download the sample image from the PyTorch hub repository.
    urllib.request.urlretrieve(url, filename)
| |
| from PIL import Image |
| from torchvision import transforms |
| |
| input_image = Image.open(filename).convert("RGB") |
| |
| transforms_list = [ |
| transforms.Resize((image_size, image_size)), |
| transforms.ToTensor(), |
| ] |
| if normalization == "imagenet": |
| transforms_list.append( |
| transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), |
| ) |
| |
| preprocess = transforms.Compose(transforms_list) |
| |
| input_tensor = preprocess(input_image) |
    # Add a batch dimension and wrap in a tuple to match the sample-input format.
    input_batch = input_tensor.unsqueeze(0)
    return (input_batch,)
| |
| |
| def init_model(model_name): |
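    """Instantiate models that are not registered in examples/models: a few
    plain torchvision models, plus YOLO_NAS_S via super-gradients. Returns
    None for unrecognized model names."""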
| if model_name == "convnext_small": |
| return torchvision.models.convnext_small() |
| if model_name == "densenet161": |
| return torchvision.models.densenet161() |
| if model_name == "shufflenet_v2_x1_0": |
| return torchvision.models.shufflenet_v2_x1_0() |
| if model_name == "YOLO_NAS_S": |
| try: |
| from super_gradients.common.object_names import Models |
| from super_gradients.training import models |
| except ImportError: |
| raise ImportError( |
| "Please install super-gradients to use the YOLO_NAS_S model." |
| ) |
| |
| return models.get(Models.YOLO_NAS_S, pretrained_weights="coco") |
| |
| return None |
| |
| |
| def get_sample_inputs(model_name): |
| # Lock the random seed for reproducibility |
| torch.manual_seed(42) |
| |
| if is_vision_model(model_name): |
| return get_vision_model_sample_input() |
| if model_name == "YOLO_NAS_S": |
| input_batch = get_dog_image_tensor(640) |
| return input_batch |
| |
| return None |
| |
| |
| def get_dynamic_shapes(model_name): |
| if is_vision_model(model_name): |
| return get_vision_model_dynamic_shapes() |
| |
| return None |
| |
| |
| def main() -> None: # noqa: C901 |
| logger = logging.getLogger("") |
| logger.setLevel(logging.INFO) |
| |
| parser = argparse.ArgumentParser() |
| parser.add_argument( |
| "-m", |
| "--model_name", |
| required=True, |
| help=f"provide a model name. Valid ones: {list(MODEL_NAME_TO_MODEL.keys())}", |
| ) |
| |
| parser.add_argument( |
| "-fp16", |
| "--force_fp16", |
| action=argparse.BooleanOptionalAction, |
| default=False, |
| help="Force fp32 tensors to be converted to fp16 internally. Input/s outputs " |
| "will be converted to/from fp32 when entering/exiting the delegate. Default is " |
| "False", |
| ) |
| |
| parser.add_argument( |
| "--small_texture_limits", |
| action=argparse.BooleanOptionalAction, |
| default=False, |
| help="sets the default texture limit to be (2048, 2048, 2048) which is " |
| "compatible with more devices (i.e. desktop/laptop GPUs) compared to the " |
| "default (16384, 16384, 2048) which is more targeted for mobile GPUs. Default " |
| "is False.", |
| ) |
| |
| parser.add_argument( |
| "--skip_memory_planning", |
| action=argparse.BooleanOptionalAction, |
| default=False, |
| help="Skips memory planning pass while lowering, which can be used for " |
| "debugging. Default is False.", |
| ) |
| |
| parser.add_argument( |
| "-s", |
| "--strict", |
| action=argparse.BooleanOptionalAction, |
| default=True, |
| help="whether to export with strict mode. Default is True", |
| ) |
| |
| parser.add_argument( |
| "-d", |
| "--dynamic", |
| action=argparse.BooleanOptionalAction, |
| default=False, |
| help="Enable dynamic shape support. Default is False", |
| ) |
| |
| parser.add_argument( |
| "-r", |
| "--etrecord", |
| required=False, |
| default="", |
| help="Generate and save an ETRecord to the given file location", |
| ) |
| |
| parser.add_argument("-o", "--output_dir", default=".", help="output directory") |
| |
| parser.add_argument( |
| "-b", |
| "--bundled", |
| action=argparse.BooleanOptionalAction, |
| default=False, |
| help="Export as bundled program (.bpte) instead of regular program (.pte). Default is False", |
| ) |
| |
| parser.add_argument( |
| "-t", |
| "--test", |
| action=argparse.BooleanOptionalAction, |
| default=False, |
| help="Execute lower_module_and_test_output to validate the model. Default is False", |
| ) |
| |
| parser.add_argument( |
| "--save_inputs", |
| action=argparse.BooleanOptionalAction, |
| default=False, |
| help="Whether to save the inputs to the model. Default is False", |
| ) |
| |
| args = parser.parse_args() |
| |
| if args.model_name in MODEL_NAME_TO_MODEL: |
| model, example_inputs, _, dynamic_shapes = EagerModelFactory.create_model( |
| *MODEL_NAME_TO_MODEL[args.model_name] |
| ) |
| else: |
| model = init_model(args.model_name) |
| example_inputs = get_sample_inputs(args.model_name) |
| dynamic_shapes = get_dynamic_shapes(args.model_name) if args.dynamic else None |
| |
| if model is None: |
| raise RuntimeError( |
| f"Model {args.model_name} is not a valid name. " |
| f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}." |
| ) |
| |
| # Prepare model |
| model.eval() |
| |
| # Setup compile options |
| compile_options = {} |
| if args.dynamic: |
| compile_options["require_dynamic_shapes"] = True |
| # Try to manually get the dynamic shapes for the model if not set |
| if dynamic_shapes is None: |
| dynamic_shapes = get_dynamic_shapes(args.model_name) |
| |
| if args.force_fp16: |
| compile_options["force_fp16"] = True |
| if args.skip_memory_planning: |
| compile_options["skip_memory_planning"] = True |
| if args.small_texture_limits: |
| compile_options["small_texture_limits"] = True |
| |
| logging.info(f"Exporting model {args.model_name} with Vulkan delegate") |
| |
| # Export the model using torch.export |
| if dynamic_shapes is not None: |
| program = export( |
| model, example_inputs, dynamic_shapes=dynamic_shapes, strict=args.strict |
| ) |
| else: |
| program = export(model, example_inputs, strict=args.strict) |
| |
| # Transform and lower with Vulkan partitioner |
| edge_program = to_edge_transform_and_lower( |
| program, |
| partitioner=[VulkanPartitioner(compile_options)], |
        generate_etrecord=bool(args.etrecord),
| ) |
| |
| logging.info( |
| f"Exported and lowered graph:\n{edge_program.exported_program().graph}" |
| ) |
| |
| # Create executorch program |
| exec_prog = edge_program.to_executorch() |
| |
| # Save ETRecord if requested |
| if args.etrecord: |
| exec_prog.get_etrecord().save(args.etrecord) |
| logging.info(f"Saved ETRecord to {args.etrecord}") |
| |
| # Save the program |
| output_filename = f"{args.model_name}_vulkan" |
| |
| atol = 1e-4 |
| rtol = 1e-4 |
| |
| # If forcing fp16, then numerical divergence is expected |
| if args.force_fp16: |
| atol = 2e-2 |
| rtol = 1e-1 |
| |
| # Save regular program |
| save_pte_program(exec_prog, output_filename, args.output_dir) |
| logging.info( |
| f"Model exported and saved as {output_filename}.pte in {args.output_dir}" |
| ) |
| |
| if args.save_inputs: |
| inputs_flattened, _ = tree_flatten(example_inputs) |
| for i, input_tensor in enumerate(inputs_flattened): |
| input_filename = os.path.join(args.output_dir, f"input{i}.bin") |
| input_tensor.numpy().tofile(input_filename) |
| f"Model input saved as {input_filename} in {args.output_dir}" |
| |
| if args.bundled: |
| # Create bundled program |
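        # A bundled program packages the ExecuTorch program together with test
        # inputs and reference outputs so that a runner can verify on-device
        # results against the eager model.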
| logging.info("Creating bundled program with test cases") |
| |
| # Generate expected outputs by running the model |
| expected_outputs = [model(*example_inputs)] |
| |
| # Flatten sample inputs to match expected format |
| inputs_flattened, _ = tree_flatten(example_inputs) |
| |
| # Create test suite with the sample inputs and expected outputs |
| test_suites = [ |
| MethodTestSuite( |
| method_name="forward", |
| test_cases=[ |
| MethodTestCase( |
| inputs=inputs_flattened, |
| expected_outputs=expected_outputs, |
| ) |
| ], |
| ) |
| ] |
| |
| # Create bundled program |
| bp = BundledProgram(exec_prog, test_suites) |
| |
| # Serialize to flatbuffer |
| bp_buffer = serialize_from_bundled_program_to_flatbuffer(bp) |
| |
| # Save bundled program |
| bundled_output_path = f"{args.output_dir}/{output_filename}.bpte" |
| with open(bundled_output_path, "wb") as file: |
| file.write(bp_buffer) |
| |
| logging.info( |
| f"Bundled program exported and saved as {output_filename}.bpte in {args.output_dir}" |
| ) |
| |
| # Test the model if --test flag is provided |
| if args.test: |
| test_result = test_utils.run_and_check_output( |
| reference_model=model, |
| executorch_program=exec_prog, |
| sample_inputs=example_inputs, |
| atol=atol, |
| rtol=rtol, |
| ) |
| |
| if test_result: |
| logging.info( |
| "✓ Model test PASSED - outputs match reference within tolerance" |
| ) |
| else: |
| logging.error("✗ Model test FAILED - outputs do not match reference") |
| raise RuntimeError( |
| "Model validation failed: ExecuTorch outputs do not match reference model outputs" |
| ) |
| |
| |
| if __name__ == "__main__": |
| with torch.no_grad(): |
| main() # pragma: no cover |