# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import json
from typing import Optional

import torch

from executorch.examples.models.llama.export_llama_lib import (
    EXECUTORCH_DEFINED_MODELS,
    TORCHTUNE_DEFINED_MODELS,
)

from executorch.extension.pybindings.portable_lib import _load_for_executorch

# Load custom ops and quantized ops.
from executorch.extension.pybindings import portable_lib  # noqa # usort: skip

from executorch.examples.models.llama.runner.generation import LlamaRunner

# Note: import this after portable_lib
from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa # usort: skip

from executorch.kernels import quantized  # noqa


class NativeLlamaRunner(LlamaRunner):
    """
    Runs llama via ExecuTorch with the provided .pte file.
    """

    def __init__(self, args):
        with open(args.params, "r") as f:
            params = json.loads(f.read())
        super().__init__(
            tokenizer_path=args.tokenizer,
            max_seq_len=args.max_len,
            max_batch_size=1,
            use_kv_cache=args.kv_cache,
            vocab_size=params["vocab_size"],
        )
        self.model = _load_for_executorch(args.pte)

    def forward(
        self,
        tokens: torch.Tensor,
        input_pos: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # The exported method takes its inputs as a tuple and returns a list
        # of outputs; the logits tensor is the first entry. `input_pos` is
        # only passed when the model was exported with a KV cache.
        return (
            self.model.forward((tokens, input_pos))
            if input_pos is not None
            else self.model.forward((tokens,))
        )[0]
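
# A minimal sketch of calling NativeLlamaRunner.forward directly, assuming the
# model was exported with a KV cache (`bos_id` and the tensor shapes below are
# illustrative assumptions, not taken from this file):
#
#   runner = NativeLlamaRunner(args)
#   logits = runner.forward(
#       tokens=torch.tensor([[bos_id]], dtype=torch.long),  # [1, 1]: current token
#       input_pos=torch.tensor([0], dtype=torch.long),      # position of that token
#   )
#
# In practice this loop is driven by LlamaRunner.text_completion, as in main().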


def build_args_parser() -> argparse.ArgumentParser:
    # TODO: merge these with build_args_parser from export_llama_lib.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        default="llama3",
        choices=EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS,
    )
    parser.add_argument(
        "-f",
        "--pte",
        type=str,
        default=None,
        help="path to exported executorch .pte file",
    )
    parser.add_argument(
        "-p", "--params", type=str, default=None, help="model params file"
    )
    parser.add_argument(
        "-t",
        "--tokenizer",
        type=str,
        default=None,
        help="path to the tokenizer model file",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="Hello",
        help="prompt to complete",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.6,
        help="sampling temperature; higher values produce more random output",
    )
    parser.add_argument(
        "-kv",
        "--kv_cache",
        action="store_true",
        help="set if the model was exported with a KV cache",
    )
    parser.add_argument(
        "--max_len",
        type=int,
        default=128,
        help="Maximum length of the generated response sequence.",
    )
    return parser


def main() -> None:
    parser = build_args_parser()
    args = parser.parse_args()
    runner = NativeLlamaRunner(args)
    generated_tokens = runner.text_completion(
        prompt=args.prompt,
        temperature=args.temperature,
    )
    print(f"Response: {generated_tokens}")


if __name__ == "__main__":
    main()  # pragma: no cover
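
# Example invocation. The module path assumes this file lives at
# executorch/examples/models/llama/runner/native.py, and the artifact paths
# are illustrative placeholders:
#
#   python -m executorch.examples.models.llama.runner.native \
#       --pte llama3.pte \
#       --params params.json \
#       --tokenizer tokenizer.model \
#       --kv_cache \
#       --prompt "Hello" \
#       --max_len 128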