# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
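"""Evaluation entry points for Llama models with the lm-evaluation-harness
library, covering pre-export graph modules, ExecuTorch pybindings, and the
ExecuTorch runtime."""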
import argparse
from typing import Optional, Union
import torch
from executorch.examples.models.llama.export_llama_lib import (
get_quantizer_and_quant_params,
)
from executorch.examples.models.llama.tokenizer.tiktoken import Tokenizer as Tiktoken
from executorch.extension.llm.export.builder import LLMEdgeManager
from executorch.extension.llm.tokenizer.tokenizer import (
Tokenizer as SentencePieceTokenizer,
)
from executorch.extension.llm.tokenizer.utils import get_tokenizer
from lm_eval.evaluator import simple_evaluate
from .evaluate.eager_eval import EagerEvalWrapper
from .export_llama_lib import (
_prepare_for_llama_export,
build_args_parser as _build_args_parser,
)
class GraphModuleEvalWrapper(EagerEvalWrapper):
"""
A wrapper class for evaluating a torch.fx.GraphModule produced by the
export flow with the lm-evaluation-harness library.
"""
def __init__(
self,
model: torch.fx.GraphModule,
tokenizer: Union[SentencePieceTokenizer, Tiktoken],
max_seq_length: Optional[int] = None,
use_kv_cache: bool = False,
generate_full_logits: bool = False,
enable_dynamic_shape: bool = True,
):
super().__init__(
model=model, tokenizer=tokenizer, max_seq_length=max_seq_length
)
self._model = model.to(self.device)
self._use_kv_cache = use_kv_cache
self._generate_full_logits = generate_full_logits
self._enable_dynamic_shape = enable_dynamic_shape
def _model_call(self, inps):
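# inps: Tensor of token ids, shape (1, seq_len).
# Returns logits of shape (1, seq_len, vocab_size).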
if self._use_kv_cache:
if not self._enable_dynamic_shape:
# A graph module exported without dynamic shapes only accepts the shape it
# was exported with, so we have to prefill one token at a time here.
result_logits = []
for pos in range(inps.shape[-1]):
pos_tensor = torch.tensor([pos], dtype=torch.int64)
logits = self._model(inps[:, pos : pos + 1], pos_tensor)
result_logits.append(logits)
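# With full logits each step already yields a (1, 1, vocab_size) slice, so
# concatenating along dim=1 rebuilds the sequence; otherwise each step
# yields (1, vocab_size) and stacking adds the sequence dimension back.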
if self._generate_full_logits:
return torch.cat(result_logits, dim=1)
else:
return torch.stack(result_logits, dim=1)
else:
pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device)
# Batch process the whole sequence.
logits = self._model(inps[:, : self._max_seq_length], pos_tensor)
return logits
else:
return self._model(inps)
def _model_generate(self, context, max_length, eos_token_id):
raise NotImplementedError("generation is not supported by this eval wrapper")
class ETPybindEvalWrapper(EagerEvalWrapper):
"""
A wrapper class for ExecuTorch pybindings integration with the
lm-evaluation-harness library.
"""
def __init__(
self,
model: str,
tokenizer: Union[SentencePieceTokenizer, Tiktoken],
max_seq_length: Optional[int] = None,
):
super().__init__(None, tokenizer, max_seq_length) # pyre-ignore
self._model = model # Expects model to be path to a .pte file
from executorch.extension.pybindings.portable_lib import _load_for_executorch
# Load custom ops and quantized ops.
from executorch.extension.pybindings import portable_lib # noqa # usort: skip
# Note: import this after portable_lib
from executorch.extension.llm.custom_ops import ( # noqa
sdpa_with_kv_cache, # usort: skip
)
from executorch.kernels import quantized # noqa
self._et_model = _load_for_executorch(self._model)
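# The exported .pte exposes a "use_kv_cache" method; query it to learn
# whether the model was exported with a KV cache.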
self._use_kv_cache = self._et_model.run_method("use_kv_cache")[0] # pyre-ignore
def _model_call(self, inps):
# Given inps (tokens), return the logits from a single forward call
# inps: Tensor of shape (1, max_seq_len - 1)
# logits: Tensor of shape (1, max_seq_len - 1, vocab_size)
result = []
if self._use_kv_cache:
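# KV-cache models take (tokens, start_pos); feed the whole (truncated)
# prompt at position 0 so the cache is prefilled in a single forward call.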
pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device)
result = self._et_model.forward(
(inps[:, : self._max_seq_length], pos_tensor)
)
else:
result = self._et_model.forward((inps,))
if result[0].dim() != 3:
raise ValueError(
f"Dim of logits must be 3 for evaluation. Got {result[0].dim()} here. Add --generate_full_logits in export_llama to generate a pte file with full logits."
)
return result[0]
class ETRunnerEvalWrapper(EagerEvalWrapper):
"""
A wrapper class for ExecuTorch Runtime integration with the
lm-evaluation-harness library.
"""
def __init__(
self,
model: str,
tokenizer: Union[SentencePieceTokenizer, Tiktoken],
tokenizer_bin: str,
max_seq_length: Optional[int] = None,
):
super().__init__(None, tokenizer, max_seq_length) # pyre-ignore
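# model: path to the .pte file to evaluate; tokenizer_bin: path to the
# serialized tokenizer consumed by the runtime runner.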
self._model = model
self._tokenizer_bin = tokenizer_bin
def _model_call(self, inps):
# Given inps (tokens), return the logits from a single
# forward call
# Example:
# inps: Tensor of shape (1, N)
# logits: Tensor of shape (1, N, vocab_size)
raise NotImplementedError(
"evaluating a .pte via the ExecuTorch runtime runner is not implemented yet"
)
def gen_eval_wrapper(
model_name: str,
args: argparse.ArgumentParser,
):
"""
Generates a wrapper interface around the provided model and tokenizer for
the lm-evaluation-harness library.
Returns:
eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library.
"""
tokenizer = get_tokenizer(args.tokenizer_path) # pyre-ignore
# ExecuTorch Binary Evaluation
if (model := args.pte) is not None: # pyre-ignore
if (tokenizer_bin := args.tokenizer_bin) is not None: # pyre-ignore
# ETRunnerEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated at runtime
return ETRunnerEvalWrapper(
model=model,
tokenizer=tokenizer,
tokenizer_bin=tokenizer_bin,
max_seq_length=args.max_seq_length, # pyre-ignore
)
# ETPybindEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated with pybindings
return ETPybindEvalWrapper(
model=model,
tokenizer=tokenizer,
# Exported model takes at most (max_seq_length - 1) tokens.
# Note that the eager model takes at most max_seq_length tokens.
max_seq_length=args.max_seq_length - 1,
)
pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
# GPTFastEvalWrapper: Create a wrapper around a pre-exported model
manager: LLMEdgeManager = _prepare_for_llama_export(args)
if len(quantizers) != 0:
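# PT2E path: capture the graph with export() and apply the configured
# quantizers, then evaluate the resulting pre-autograd graph module.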
manager = manager.export().pt2e_quantize(quantizers)
model = (
manager.pre_autograd_graph_module.to(device="cuda") # pyre-ignore
if torch.cuda.is_available()
else manager.pre_autograd_graph_module.to(device="cpu")
)
return GraphModuleEvalWrapper(
model=model,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
use_kv_cache=args.use_kv_cache, # pyre-ignore
enable_dynamic_shape=args.enable_dynamic_shape, # pyre-ignore
)
else:
# TODO: use manager.pre_autograd_graph_module for the eval to remove this
# if-else branch over quantizers. Currently export_for_training only works
# with --kv_cache and fails without KV-cache mode.
model = (
manager.model.eval().to(device="cuda")
if torch.cuda.is_available()
else manager.model.eval().to(device="cpu")
)
# Optionally save the checkpoint after eager model preparation is done.
# The saved checkpoint can be used to run evaluations on other platforms,
# or with data that is not available to eval_llama. Saving it here keeps
# those runs consistent with eval_llama, so the accuracy results from
# eval_llama can serve as a reference for other evaluations.
if args.output_eager_checkpoint_file is not None: # pyre-ignore
torch.save(model, args.output_eager_checkpoint_file)
return EagerEvalWrapper(
model=model,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
use_kv_cache=args.use_kv_cache,
)
def build_args_parser() -> argparse.ArgumentParser:
# Start with arg parser from export_llama_lib
parser = _build_args_parser()
# Add additional args specific to eval
parser.add_argument(
"--tasks",
nargs="+",
type=str,
default=["wikitext"],
help="list of lm-eluther tasks to evaluate usage: --tasks task1 task2",
)
parser.add_argument(
"--limit",
type=int,
default=None,
help="number of samples to evalulate. If not set, evaluate all samples",
)
parser.add_argument(
"-f",
"--num_fewshot",
type=int,
default=None,
metavar="N",
help="Number of examples in few-shot context",
)
# Add additional args specific to eval via an ET Runner
# Note: For initial integration, the tokenizer.model is also required
parser.add_argument(
"--pte",
type=str,
default=None,
help="[For ExecuTorch] Path to the ExecuTorch model being evaluated. If provided, don't go through the export flow",
)
parser.add_argument(
"--tokenizer_bin",
type=str,
default=None,
help="[For ExecuTorch] Path to the Tokenizer binary for evaluating ExecuTorch models via runtime",
)
parser.add_argument(
"--output_eager_checkpoint_file",
type=str,
default=None,
help="Save the checkpoint after source transformations, for other evaluation platform to run the same checkpoint.",
)
return parser
def eval_llama(
model_name: str,
args: argparse.ArgumentParser,
) -> None:
# Generate the eval wrapper
eval_wrapper = gen_eval_wrapper(model_name, args)
# Needed for loading mmlu dataset.
# See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
# pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `tasks`
if args.tasks and "mmlu" in args.tasks:
import datasets
datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
# Evaluate the model
with torch.no_grad():
eval_results = simple_evaluate(
model=eval_wrapper,
tasks=args.tasks,
num_fewshot=args.num_fewshot, # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `num_fewshot`
limit=args.limit, # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `limit`
)
for task, res in eval_results["results"].items():
print(f"{task}: {res}")