blob: a7c906f0cdc9929e1ae05850b77d66beef265a36 [file] [log] [blame]
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch.nn
from transformers import Phi3ForCausalLM
from .static_cache import ETStaticCache
class Phi3Mini(torch.nn.Module):
    """Wrapper around a HuggingFace ``Phi3ForCausalLM`` that attaches a
    statically pre-allocated KV cache (``ETStaticCache``) and exposes a
    single-step ``forward`` returning last-position logits.
    """

    def __init__(self, model: Phi3ForCausalLM, max_batch_size: int, max_seq_len: int):
        """Store the model and allocate the static cache.

        Args:
            model: the pretrained Phi-3-mini causal LM to wrap.
            max_batch_size: largest batch the cache must accommodate.
            max_seq_len: largest total sequence length the cache must hold.
        """
        super().__init__()
        self.model = model
        # The cache is sized once up front (static) and reused across calls;
        # it mirrors the wrapped model's device and dtype.
        static_cache = ETStaticCache(
            config=model.config,
            max_batch_size=max_batch_size,
            max_cache_len=max_seq_len,
            device=model.device,
            dtype=model.dtype,
        )
        self.cache = static_cache

    def forward(self, input_ids: torch.LongTensor = None) -> torch.FloatTensor:
        """Run one decoding step through the wrapped model.

        Returns the logits for the final sequence position only, shaped
        ``(batch, vocab)`` — the slice needed for next-token prediction.
        """
        outputs = self.model.forward(
            input_ids=input_ids,
            past_key_values=self.cache,
            use_cache=True,
            return_dict=True,
        )
        # Keep only the last position's logits.
        return outputs.logits[:, -1, :]