# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Script to run phi-3-mini model in eager mode.
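#
# Example invocation (the module path is a placeholder; run from wherever this
# package is importable -- the flags are the ones defined at the bottom of this
# script):
#   python -m <package>.eager --use_kv_cache --seq_len 64 --prompt "Tell me a story"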

import argparse
import time

import torch
from transformers import AutoTokenizer, Phi3ForCausalLM

from .phi_3_mini import Phi3Mini
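
# Token id that marks the end of the generated text; generation stops once the
# model produces it.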
end_of_text_token = 32000
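

# Greedy decoding without a KV cache: the full prompt plus all previously
# generated tokens are re-run through the model on every step.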
def _generate_token(args, model, prompt_tokens):
    current_token = 0
    generated_tokens = []

    print("Generating tokens:", end="", flush=True)

    while current_token != end_of_text_token and len(generated_tokens) < args.seq_len:
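        # Re-run the full prompt plus everything generated so far and greedily
        # pick the most likely next token from the last position's logits.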
        outputs = model.forward(input_ids=prompt_tokens)
        current_token = torch.argmax(outputs.logits[:, -1, :], dim=-1).item()
        print(f" {current_token}", end="", flush=True)
        generated_tokens.append(current_token)
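        # Append the chosen token so it is part of the input on the next step.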
        prompt_tokens = torch.cat(
            [prompt_tokens, torch.tensor([[current_token]], dtype=torch.long)], dim=-1
        )

    print("", flush=True)

    return generated_tokens
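
# Greedy decoding with a KV cache: the prompt is run once to fill the cache,
# after which only the newly generated token is fed in on each step.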
def _generate_token_with_kv_cache(args, model, prompt_tokens):
print("Generating tokens:", end="", flush=True)
    model = Phi3Mini(model, 1, args.seq_len + prompt_tokens.shape[-1])

    result = model.forward(input_ids=prompt_tokens)
    current_token = torch.argmax(result, dim=-1).item()
    print(f" {current_token}", end="", flush=True)
    generated_tokens = [current_token]

    while current_token != end_of_text_token and len(generated_tokens) < args.seq_len:
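        # Only the newly generated token is passed in; earlier context is served
        # from the KV cache held by the Phi3Mini wrapper.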
        result = model.forward(
            input_ids=torch.tensor([[current_token]], dtype=torch.long),
        )
        current_token = torch.argmax(result, dim=-1).item()
        print(f" {current_token}", end="", flush=True)
        generated_tokens.append(current_token)

    print("", flush=True)

    return generated_tokens
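
# Load the eager HuggingFace model and tokenizer, run greedy generation with or
# without the KV cache wrapper, then print the decoded response and elapsed time.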
def main(args):
    seed = 42
    torch.manual_seed(seed)

    model_name = "microsoft/Phi-3-mini-4k-instruct"
    model = Phi3ForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    tokens = tokenizer.encode(args.prompt, return_tensors="pt")

    start = time.time()
    generated_tokens = (
        _generate_token_with_kv_cache(args, model, tokens)
        if args.use_kv_cache
        else _generate_token(args, model, tokens)
    )
    end = time.time()

    print(
        "Generated response: \n {}".format(
            tokenizer.decode(
                generated_tokens,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
        ),
        flush=True,
    )
print(f"Time spent: {end - start}", flush=True)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-s",
        "--seq_len",
        type=int,
        default=128,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "-kv",
        "--use_kv_cache",
        default=False,
        action="store_true",
        help="Whether or not to use KV cache",
    )
    parser.add_argument(
        "-p",
        "--prompt",
        type=str,
        default="Tell me a story",
        help="Prompt as input for the model",
    )
    main(parser.parse_args())