# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# Provides builders for LLM models. These builders help users build eager LLM
# models, apply source transformations and quantization, and export them to
# ExecuTorch. A usage sketch appears at the bottom of this file.
import logging
from enum import Enum
from typing import Any, Callable, List, Optional
import torch
from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
DuplicateDynamicQuantChainPass,
)
from executorch.exir import EdgeProgramManager
from executorch.exir.backend.partitioner import Partitioner
from executorch.exir.backend.utils import format_delegated_graph
from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig
from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
from executorch.extension.export_util.utils import export_to_edge, save_pte_program
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer import Quantizer
from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
from torch.nn.attention import SDPBackend
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)
class DType(Enum):
fp32 = "fp32"
fp16 = "fp16"
bf16 = "bf16"
def to_torch_dtype(self) -> torch.dtype:
        mapping = {
            DType.fp32: torch.float32,
            DType.fp16: torch.float16,
            DType.bf16: torch.bfloat16,
        }
if self not in mapping:
raise ValueError(f"Unsupported dtype {self}")
return mapping[self]
class LLMEdgeManager:
"""
    Hosts a torch.nn.Module for an LLM model and facilitates exporting it to ExecuTorch.
"""
def __init__(
self,
model,
modelname,
max_seq_len,
dtype,
use_kv_cache,
example_inputs,
args: Optional[Any] = None,
enable_dynamic_shape: bool = False,
verbose: bool = False,
metadata: Optional[dict] = None,
dynamic_shapes: Optional[Any] = None,
):
self.model = model
# graph module returned from capture_pre_autograd_graph
self.pre_autograd_graph_module: Optional[torch.fx.GraphModule] = None
self.modelname = modelname
self.max_seq_len = max_seq_len
self.dtype = dtype
self.example_inputs = example_inputs
self.use_kv_cache = use_kv_cache
self.enable_dynamic_shape = enable_dynamic_shape
self.verbose = verbose
self.metadata = metadata
self.applied_source_transforms = []
self.edge_manager: Optional[EdgeProgramManager] = None
self.export_program = None
self.output_dir = "."
self.dynamic_shapes = dynamic_shapes
self._saved_pte_filename = None
self.args = args
def set_output_dir(self, output_dir: str) -> "LLMEdgeManager":
"""
Set the directory where the .pte file will be saved.
Args:
output_dir (str): The directory to store the .pte file.
"""
self.output_dir = output_dir
return self
def to_dtype(self, dtype_override: Optional[DType]) -> "LLMEdgeManager":
"""
Convert the model to the specified dtype.
Args:
dtype_override (Optional[DType]): Override the dtype of the model.
"""
assert not dtype_override or isinstance(
dtype_override, DType
), "Override dtype needs to be of type <DType>"
if dtype_override is not None and dtype_override != self.dtype:
torch_dtype = dtype_override.to_torch_dtype()
logging.info(f"model.to {torch_dtype}")
self.model = self.model.to(dtype=torch_dtype)
self.dtype = dtype_override
return self
def source_transform(
self, transforms: List[Callable[[torch.nn.Module], torch.nn.Module]]
) -> "LLMEdgeManager":
"""
Apply source transforms to the model. The transforms are callables that
takes nn.Module as input and returns nn.Module.
Args:
transforms (List[Callable[[torch.nn.Module], torch.nn.Module]]): A
list of source transforms.
"""
for transform in transforms:
self.model = transform(self.model)
self.applied_source_transforms.extend(transforms)
if self.verbose:
logging.info(f"Applied source transforms: {self.applied_source_transforms}")
return self
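    # A minimal sketch of a source transform (hypothetical example, kept as a
    # comment so it does not execute in the class body): any callable mapping
    # nn.Module -> nn.Module works, e.g. one that zeroes out dropout:
    #
    #     def disable_dropout(module: torch.nn.Module) -> torch.nn.Module:
    #         for m in module.modules():
    #             if isinstance(m, torch.nn.Dropout):
    #                 m.p = 0.0
    #         return module
    #
    #     manager.source_transform([disable_dropout])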
def _get_dynamic_shape(self) -> Any:
if self.dynamic_shapes:
return self.dynamic_shapes
dim = torch.export.Dim("token_dim", max=self.max_seq_len - 1)
if not self.use_kv_cache:
# Only one input argument: tokens
self.dynamic_shapes = ({1: dim},)
elif self.enable_dynamic_shape:
            # Two input arguments: tokens and input_pos, but input_pos has a static shape
self.dynamic_shapes = ({1: dim}, {0: 1})
else:
            # Two input arguments: tokens and input_pos, both with static shapes
self.dynamic_shapes = None
return self.dynamic_shapes
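    # For reference: with use_kv_cache=False, the single `tokens` input gets a
    # dynamic sequence dimension, ({1: dim},), where `dim` ranges up to
    # max_seq_len - 1. With a KV cache and dynamic shapes enabled, input_pos
    # keeps a fixed size, hence the static {0: 1} entry in its spec.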
def _get_edge_config(self) -> EdgeCompileConfig:
edge_config = EdgeCompileConfig(
_check_ir_validity=False,
_skip_type_promotion=bool(self.dtype == DType.fp16),
_skip_dim_order=True,
)
return edge_config
def capture_pre_autograd_graph(self) -> "LLMEdgeManager":
dynamic_shape = self._get_dynamic_shape()
        # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) bypasses a dynamo error when tracing
        # 2. torch.no_grad() strips dropout ops (it is unclear why training ops show up otherwise)
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
# pyre-fixme[8]
if hasattr(self.args, "qnn") and self.args.qnn:
                # TODO: this is temporary; export_for_training doesn't work with qnn either. We need a
                # functional graph. See https://github.com/pytorch/executorch/pull/4627 for more details.
self.pre_autograd_graph_module = torch.export.export(
self.model,
self.example_inputs,
dynamic_shapes=dynamic_shape,
strict=True,
).module()
else:
self.pre_autograd_graph_module = capture_pre_autograd_graph(
self.model, self.example_inputs, dynamic_shapes=dynamic_shape
)
return self
def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
"""
        Quantize the model via the pt2e flow and return the LLMEdgeManager with the quantized model.
Args:
quantizers (Optional[List[Quantizer]]): A list of quantizers.
"""
assert (
self.edge_manager is None
), "export_to_edge is already called, please call pt2e_quantize before export_to_edge"
logging.info(f"Using pt2e {quantizers} to quantizing the model...")
        # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) bypasses a dynamo error when tracing
        # 2. torch.no_grad() strips dropout ops (it is unclear why training ops show up otherwise)
if quantizers:
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
if self.verbose:
logging.info(f"Applied quantizers: {quantizers}")
composed_quantizer = ComposableQuantizer(quantizers)
assert (
self.pre_autograd_graph_module is not None
), "Please run capture_pre_autograd_graph first"
m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
# Calibrate
m(*self.example_inputs)
m = convert_pt2e(m)
DuplicateDynamicQuantChainPass()(m)
self.pre_autograd_graph_module = m
return self
else:
logging.info("No quantizer provided, passing...")
return self
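    # A minimal sketch of building a quantizer list for this flow, assuming
    # torch.ao's XNNPACKQuantizer is available (kept as a comment, not executed):
    #
    #     from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    #         XNNPACKQuantizer,
    #         get_symmetric_quantization_config,
    #     )
    #
    #     quantizer = XNNPACKQuantizer()
    #     quantizer.set_global(get_symmetric_quantization_config(is_dynamic=True))
    #     manager.pt2e_quantize([quantizer])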
def export_to_edge(self) -> "LLMEdgeManager":
"""
        Export the model to the Edge dialect and return the LLMEdgeManager.
"""
dynamic_shape = self._get_dynamic_shape()
edge_config = self._get_edge_config()
        # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) bypasses a dynamo error when tracing
        # 2. torch.no_grad() strips dropout ops (it is unclear why training ops show up otherwise)
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
if self.pre_autograd_graph_module is None:
                # Run capture_pre_autograd_graph if it has not run yet
self.capture_pre_autograd_graph()
self.edge_manager = export_to_edge(
self.pre_autograd_graph_module, # pyre-fixme[6]
self.example_inputs,
dynamic_shapes=dynamic_shape,
edge_constant_methods=self.metadata,
edge_compile_config=edge_config,
verbose=self.verbose,
)
return self
def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManager":
"""
Partition the model and lower to different backends. The signature is
aligned with the signature of `to_backend` method of EdgeManager.
Args:
            partitioners (Optional[List[Partitioner]]): One or more
                partitioners to be sent to EdgeManager.to_backend().
"""
if partitioners is None:
logging.info("No partitioner provided, passing...")
else:
for partitioner in partitioners:
if partitioner is not None:
assert (
self.edge_manager is not None
), "Need to run export_to_edge() first"
self.edge_manager = self.edge_manager.to_backend(partitioner)
if self.verbose:
logging.info(
format_delegated_graph(
self.edge_manager.exported_program().graph_module
)
)
logging.info(f"Applied partitioners: {partitioner}")
else:
logging.info("No partitioner provided, passing...")
continue
return self
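    # A minimal sketch of lowering to a backend, assuming the XNNPACK
    # partitioner import path below (kept as a comment, not executed):
    #
    #     from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
    #         XnnpackPartitioner,
    #     )
    #
    #     manager.to_backend([XnnpackPartitioner()])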
def to_executorch(self) -> "LLMEdgeManager":
"""
        Lower the model to ExecuTorch and get an ExecutorchProgram.
"""
assert self.edge_manager, "Need to run export_to_edge() first"
self.export_program = self.edge_manager.to_executorch(
ExecutorchBackendConfig(
extract_delegate_segments=True,
passes=[
QuantFusionPass(),
],
memory_planning_pass=MemoryPlanningPass(
"greedy", alloc_graph_input=False
),
sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
)
)
logging.info(
"Required memory for activation in bytes: {}".format(
self.export_program._emitter_output.program.execution_plan[
0
].non_const_buffer_sizes
),
)
return self
def save_to_pte(self, output_name: str) -> None:
"""
Save the model to a .pte file.
Args:
            output_name (str): The name of the .pte file.
"""
assert output_name, "Need a valid output name"
filename = save_pte_program(self.export_program, output_name, self.output_dir)
self._saved_pte_filename = filename
def get_saved_pte_filename(self) -> Optional[str]:
"""
        Return the filename of the most recently saved .pte file, or None if the model has not been saved.
"""
return self._saved_pte_filename
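

if __name__ == "__main__":
    # A minimal end-to-end sketch of the intended flow, using a hypothetical
    # toy model instead of a real LLM; exact arguments vary by model and backend.
    class _ToyLM(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.emb = torch.nn.Embedding(32, 16)
            self.proj = torch.nn.Linear(16, 32)

        def forward(self, tokens: torch.Tensor) -> torch.Tensor:
            return self.proj(self.emb(tokens))

    manager = (
        LLMEdgeManager(
            model=_ToyLM(),
            modelname="toy_lm",
            max_seq_len=16,
            dtype=DType.fp32,
            use_kv_cache=False,
            example_inputs=(torch.ones(1, 8, dtype=torch.long),),
        )
        .set_output_dir(".")
        .capture_pre_autograd_graph()
        .pt2e_quantize(None)  # no quantizers in this sketch
        .export_to_edge()
        .to_backend(None)  # no partitioners in this sketch
        .to_executorch()
    )
    manager.save_to_pte("toy_lm")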