# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


# Script to run the phi-3-mini model in eager mode.
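#
# Usage (flags are defined by the argument parser at the bottom of this file):
#   -p / --prompt         prompt text fed to the model (default: "Tell me a story")
#   -s / --seq_len        maximum number of tokens to generate (default: 128)
#   -kv / --use_kv_cache  generate with the static KV cache wrapper instead of
#                         re-running the full sequence at every step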

import argparse
import time

import torch

from transformers import AutoTokenizer, Phi3ForCausalLM

from .phi_3_mini import Phi3Mini

# End-of-text token id for Phi-3-mini; generation stops once the model emits it.
end_of_text_token = 32000


def _generate_token(args, model, prompt_tokens):
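    """Greedy-decode up to args.seq_len tokens without a KV cache.

    Each step re-runs the model over the entire prompt plus everything generated
    so far, stopping early if the end-of-text token is produced.
    """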
    current_token = 0
    generated_tokens = []

    print("Generating tokens:", end="", flush=True)

    while current_token != end_of_text_token and len(generated_tokens) < args.seq_len:
        outputs = model.forward(input_ids=prompt_tokens)
        current_token = torch.argmax(outputs.logits[:, -1, :], dim=-1).item()
        print(f" {current_token}", end="", flush=True)
        generated_tokens.append(current_token)
        prompt_tokens = torch.cat(
            [prompt_tokens, torch.tensor([[current_token]], dtype=torch.long)], dim=-1
        )

    print("", flush=True)

    return generated_tokens


def _generate_token_with_kv_cache(args, model, prompt_tokens):
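    """Greedy-decode up to args.seq_len tokens using the Phi3Mini KV-cache wrapper.

    The prompt is processed once to prefill the cache; after that, each step
    feeds only the most recently generated token back into the model.
    """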
    print("Generating tokens:", end="", flush=True)

    # Wrap the eager model in the Phi3Mini adapter, which manages a static KV
    # cache sized for the prompt plus up to seq_len generated tokens (batch size 1).
    model = Phi3Mini(model, 1, args.seq_len + prompt_tokens.shape[-1])
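    # Prefill: run the whole prompt once to populate the cache; the wrapper
    # returns the logits used to pick the first generated token.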
    result = model.forward(input_ids=prompt_tokens)

    current_token = torch.argmax(result, dim=-1).item()
    print(f" {current_token}", end="", flush=True)
    generated_tokens = [current_token]

    # Decode loop: feed back one token at a time; the cached keys/values stand in
    # for the rest of the sequence.
    while current_token != end_of_text_token and len(generated_tokens) < args.seq_len:
        result = model.forward(
            input_ids=torch.tensor([[current_token]], dtype=torch.long),
        )
        current_token = torch.argmax(result, dim=-1).item()
        print(f" {current_token}", end="", flush=True)
        generated_tokens.append(current_token)

    print("", flush=True)

    return generated_tokens


def main(args):
    seed = 42
    torch.manual_seed(seed)
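    # Load the eager Hugging Face model and tokenizer.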
    model_name = "microsoft/Phi-3-mini-4k-instruct"
    model = Phi3ForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Encode the prompt into a [1, prompt_len] tensor of token ids.
    tokens = tokenizer.encode(args.prompt, return_tensors="pt")

    # Time greedy generation; gradients are not needed for inference.
    start = time.time()
    with torch.no_grad():
        generated_tokens = (
            _generate_token_with_kv_cache(args, model, tokens)
            if args.use_kv_cache
            else _generate_token(args, model, tokens)
        )
    end = time.time()

    print(
        "Generated response: \n {}".format(
            tokenizer.decode(
                generated_tokens,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
        ),
        flush=True,
    )
    print(f"Time spent: {end - start:.2f} seconds", flush=True)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-s",
        "--seq_len",
        type=int,
        default=128,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "-kv",
        "--use_kv_cache",
        action="store_true",
        help="Whether or not to use the KV cache",
    )
    parser.add_argument(
        "-p",
        "--prompt",
        type=str,
        default="Tell me a story",
        help="Prompt as input for the model",
    )
    main(parser.parse_args())