# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import argparse

from typing import Optional, Union

import torch
from executorch.examples.models.llama.export_llama_lib import (
    get_quantizer_and_quant_params,
)
from executorch.examples.models.llama.tokenizer.tiktoken import Tokenizer as Tiktoken

from executorch.extension.llm.export.builder import LLMEdgeManager
from executorch.extension.llm.tokenizer.tokenizer import (
    Tokenizer as SentencePieceTokenizer,
)
from executorch.extension.llm.tokenizer.utils import get_tokenizer
from lm_eval.evaluator import simple_evaluate

from .evaluate.eager_eval import EagerEvalWrapper

from .export_llama_lib import (
    _prepare_for_llama_export,
    build_args_parser as _build_args_parser,
)


class GraphModuleEvalWrapper(EagerEvalWrapper):
    """
    A wrapper class for evaluating an exported torch.fx.GraphModule (e.g. the
    pre-autograd graph module produced by the export flow) with the
    lm-evaluation-harness library.
    """
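    # Illustrative construction only (it mirrors gen_eval_wrapper below): the graph
    # module typically comes from LLMEdgeManager after export / pt2e quantization,
    # e.g.
    #   manager = _prepare_for_llama_export(args).export().pt2e_quantize(quantizers)
    #   wrapper = GraphModuleEvalWrapper(
    #       model=manager.pre_autograd_graph_module,
    #       tokenizer=get_tokenizer(args.tokenizer_path),
    #       max_seq_length=args.max_seq_length,
    #       use_kv_cache=args.use_kv_cache,
    #   )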

    def __init__(
        self,
        model: torch.fx.GraphModule,
        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
        max_seq_length: Optional[int] = None,
        use_kv_cache: bool = False,
        generate_full_logits: bool = False,
        enable_dynamic_shape: bool = True,
    ):
        super().__init__(
            model=model, tokenizer=tokenizer, max_seq_length=max_seq_length
        )
        self._model = model.to(self.device)
        self._use_kv_cache = use_kv_cache
        self._generate_full_logits = generate_full_logits
        self._enable_dynamic_shape = enable_dynamic_shape

    def _model_call(self, inps):
        if self._use_kv_cache:
            if not self._enable_dynamic_shape:
                # A graph module exported without dynamic shapes only accepts the
                # input shape it was exported with, so prefill one token at a time.
                result_logits = []
                for pos in range(inps.shape[-1]):
                    pos_tensor = torch.tensor([pos], dtype=torch.int64)
                    logits = self._model(inps[:, pos : pos + 1], pos_tensor)
                    result_logits.append(logits)
                if self._generate_full_logits:
                    # Each entry is (batch, 1, vocab_size); concatenate along the
                    # sequence dimension.
                    return torch.cat(result_logits, dim=1)
                else:
                    # Each entry is (batch, vocab_size); stack into
                    # (batch, seq_len, vocab_size).
                    return torch.stack(result_logits, dim=1)
            else:
                pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device)
                # Prefill the whole sequence in one call, starting at position 0.
                logits = self._model(inps[:, : self._max_seq_length], pos_tensor)
                return logits

        else:
            return self._model(inps)

    def _model_generate(self, context, max_length, eos_token_id):
        raise NotImplementedError(
            "Generation is not supported by GraphModuleEvalWrapper; only "
            "loglikelihood-style evaluation via _model_call is supported."
        )


class ETPybindEvalWrapper(EagerEvalWrapper):
    """
    A wrapper class for ExecuTorch py-binded integration with the
    lm-evaluation-harness library.
    """
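    # A minimal usage sketch (paths are hypothetical; see gen_eval_wrapper below):
    #   wrapper = ETPybindEvalWrapper(
    #       model="/path/to/llama.pte",
    #       tokenizer=get_tokenizer("/path/to/tokenizer.model"),
    #       # The exported .pte consumes at most (max_seq_length - 1) tokens.
    #       max_seq_length=args.max_seq_length - 1,
    #   )
    #   results = simple_evaluate(model=wrapper, tasks=["wikitext"])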

    def __init__(
        self,
        model: str,
        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
        max_seq_length: Optional[int] = None,
    ):
        super().__init__(None, tokenizer, max_seq_length)  # pyre-ignore
        self._model = model  # Expects model to be path to a .pte file

        from executorch.extension.pybindings.portable_lib import _load_for_executorch

        # Load custom ops and quantized ops.
        from executorch.extension.pybindings import portable_lib  # noqa # usort: skip

        # Note: import this after portable_lib
        from executorch.extension.llm.custom_ops import (  # noqa
            sdpa_with_kv_cache,  # usort: skip
        )
        from executorch.kernels import quantized  # noqa

        self._et_model = _load_for_executorch(self._model)
        self._use_kv_cache = self._et_model.run_method("use_kv_cache")[0]  # pyre-ignore

    def _model_call(self, inps):
        # Given inps (tokens), return the logits from a single forward call.
        # inps: Tensor of shape (1, max_seq_len - 1)
        # logits: Tensor of shape (1, max_seq_len - 1, vocab_size)
        if self._use_kv_cache:
            # Prefill the whole sequence in one call, starting at position 0.
            pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device)
            result = self._et_model.forward(
                (inps[:, : self._max_seq_length], pos_tensor)
            )
        else:
            result = self._et_model.forward((inps,))
        if result[0].dim() != 3:
            raise ValueError(
                f"Logits must be a 3-dimensional tensor for evaluation, got {result[0].dim()} dimensions. "
                "Pass --generate_full_logits to export_llama to generate a .pte file with full logits."
            )
        return result[0]


class ETRunnerEvalWrapper(EagerEvalWrapper):
    """
    A wrapper class for ExecuTorch Runtime integration with the
    lm-evaluation-harness library.
    """

    def __init__(
        self,
        model: str,
        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
        tokenizer_bin: str,
        max_seq_length: Optional[int] = None,
    ):
        super().__init__(None, tokenizer, max_seq_length)  # pyre-ignore
        self._model = model
        self._tokenizer_bin = tokenizer_bin

    def _model_call(self, inps):
        # Given inps (tokens), return the logits from a single forward call.
        #
        # Example:
        #   inps: Tensor of shape (1, N)
        #   logits: Tensor of shape (1, N, vocab_size)
        #
        # An eventual implementation would run the .pte file (self._model)
        # through the ExecuTorch runtime together with self._tokenizer_bin and
        # return the logits; until then, fail loudly instead of returning None.
        raise NotImplementedError(
            "Evaluation via the ExecuTorch runtime is not implemented yet."
        )


def gen_eval_wrapper(
    model_name: str,
    args: argparse.ArgumentParser,
):
    """
    Generates a wrapper interface around the provided model and tokenizer for
    the lm-evaluation-harness library.

    Returns:
        eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library.
    """
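    # Dispatch overview (a summary of the branches below):
    #   --pte and --tokenizer_bin  -> ETRunnerEvalWrapper (ExecuTorch runtime)
    #   --pte only                 -> ETPybindEvalWrapper (ExecuTorch pybindings)
    #   quantizers configured      -> GraphModuleEvalWrapper (pre-autograd graph module)
    #   otherwise                  -> EagerEvalWrapper (eager PyTorch model)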
    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore

    # ExecuTorch Binary Evaluation
    if (model := args.pte) is not None:  # pyre-ignore
        if (tokenizer_bin := args.tokenizer_bin) is not None:  # pyre-ignore
            # ETRunnerEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated at runtime
            return ETRunnerEvalWrapper(
                model=model,
                tokenizer=tokenizer,
                tokenizer_bin=tokenizer_bin,
                max_seq_length=args.max_seq_length,  # pyre-ignore
            )

        # ETPybindEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated with pybindings
        return ETPybindEvalWrapper(
            model=model,
            tokenizer=tokenizer,
            # Exported model takes at most (max_seq_length - 1) tokens.
            # Note that the eager model takes at most max_seq_length tokens.
            max_seq_length=args.max_seq_length - 1,
        )

    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
    # Otherwise, evaluate the pre-export model: GraphModuleEvalWrapper when
    # quantizers are configured, EagerEvalWrapper otherwise (see below).
    manager: LLMEdgeManager = _prepare_for_llama_export(args)

    if len(quantizers) != 0:
        manager = manager.export().pt2e_quantize(quantizers)
        model = (
            manager.pre_autograd_graph_module.to(device="cuda")  # pyre-ignore
            if torch.cuda.is_available()
            else manager.pre_autograd_graph_module.to(device="cpu")
        )
        return GraphModuleEvalWrapper(
            model=model,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            use_kv_cache=args.use_kv_cache,  # pyre-ignore
            enable_dynamic_shape=args.enable_dynamic_shape,  # pyre-ignore
        )
    else:
        # TODO: use manager.pre_autograd_graph_module for evaluation here as well,
        # so the quantizer if/else branch can be removed. Currently
        # export_for_training only works with the KV-cache option enabled
        # (args.use_kv_cache) and fails without it.
        model = (
            manager.model.eval().to(device="cuda")
            if torch.cuda.is_available()
            else manager.model.eval().to(device="cpu")
        )

        # Optionally save the checkpoint once eager-model preparation is done.
        # The saved checkpoint can then be evaluated on other platforms, or with
        # data that is not available to eval_llama, using the accuracy results
        # from eval_llama as a reference.
        if args.output_eager_checkpoint_file is not None:  # pyre-ignore
            torch.save(model, args.output_eager_checkpoint_file)

        return EagerEvalWrapper(
            model=model,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            use_kv_cache=args.use_kv_cache,
        )


def build_args_parser() -> argparse.ArgumentParser:
    # Start with the arg parser from export_llama_lib
    parser = _build_args_parser()

    # Add additional args specific to eval
    parser.add_argument(
        "--tasks",
        nargs="+",
        type=str,
        default=["wikitext"],
        help="List of lm-evaluation-harness tasks to evaluate. Usage: --tasks task1 task2",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help="Number of samples to evaluate. If not set, evaluate all samples",
    )
    parser.add_argument(
        "-f",
        "--num_fewshot",
        type=int,
        default=None,
        metavar="N",
        help="Number of examples in few-shot context",
    )
    # Add additional args specific to eval via an ET Runner
    # Note: For initial integration, the tokenizer.model is also required
    parser.add_argument(
        "--pte",
        type=str,
        default=None,
        help="[For ExecuTorch] Path to the ExecuTorch model being evaluated. If provided, don't go through the export flow",
    )
    parser.add_argument(
        "--tokenizer_bin",
        type=str,
        default=None,
        help="[For ExecuTorch] Path to the Tokenizer binary for evaluating ExecuTorch models via runtime",
    )
    parser.add_argument(
        "--output_eager_checkpoint_file",
        type=str,
        default=None,
        help="Save the checkpoint after source transformations, so that other evaluation platforms can run the same checkpoint.",
    )

    return parser


def eval_llama(
    model_name: str,
    args: argparse.ArgumentParser,
) -> None:
    # Generate the eval wrapper
    eval_wrapper = gen_eval_wrapper(model_name, args)

    # Needed for loading mmlu dataset.
    # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
    # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `tasks`
    if args.tasks and "mmlu" in args.tasks:
        import datasets

        datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True

    # Evaluate the model
    with torch.no_grad():
        eval_results = simple_evaluate(
            model=eval_wrapper,
            tasks=args.tasks,
            num_fewshot=args.num_fewshot,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `num_fewshot`
            limit=args.limit,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `limit`
        )

    for task, res in eval_results["results"].items():
        print(f"{task}: {res}")
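

# Example usage (a sketch, not a supported entry point): parse args with the
# eval-specific parser above, then run the evaluation. The model_name string is
# only illustrative here.
#
#   parser = build_args_parser()
#   args = parser.parse_args()
#   eval_llama("llama2", args)  # prints one result line per task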