# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional

import torch
from transformers import PretrainedConfig, StaticCache


class ETStaticCache(StaticCache):
"""
A customized static cache implementation, which overrides a few methods to make it exportable to ExecuTorch.
This can be removed once transformers supports static cache for Phi3 properly.
"""
def __init__(
self,
config: PretrainedConfig,
max_batch_size: int,
max_cache_len: int,
device,
dtype=torch.float32,
) -> None:
super().__init__(
config=config,
max_batch_size=max_batch_size,
max_cache_len=max_cache_len,
device=device,
dtype=dtype,
)

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
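        # StaticCache pre-allocates zero-initialized tensors of shape
        # (batch, num_heads, max_cache_len, head_dim). A cache position is
        # treated as occupied when its key vector (batch 0, head 0) contains
        # any non-zero element, so summing those booleans yields the current
        # sequence length.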
return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum().item()

    def get_usable_length(
        self, new_seq_length: int, layer_idx: Optional[int] = 0
    ) -> int:
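        # Nothing is ever evicted from a static, pre-allocated cache, so the
        # usable length is simply the number of slots already filled.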
return self.get_seq_length(layer_idx)
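

if __name__ == "__main__":
    # Minimal usage sketch (an illustrative assumption, not part of the
    # original module). The model name, sizes, and device below are arbitrary
    # example values chosen for demonstration.
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
    cache = ETStaticCache(
        config=config,
        max_batch_size=1,
        max_cache_len=128,
        device="cpu",
        dtype=torch.float32,
    )
    # The cache tensors start zeroed, so no positions are occupied yet.
    print(cache.get_seq_length())  # expected: 0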