# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# An ExecuTorch friendly implementation of Llava-1.5.
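#
# The model is split into export-friendly pieces: forward() runs the CLIP
# vision tower plus the multimodal projector, prefill() feeds the combined
# text/image embedding sequence through the Llama text model, and step()
# decodes one token at a time against the KV cache.
#
# Minimal usage sketch (assumes network access to download the
# "llava-hf/llava-1.5-7b-hf" checkpoint and the example image):
#
#   model = LlavaModel()
#   llava = model.get_eager_model()
#   prompt_before, image, prompt_after = model.get_inputs_for_prefill()
#   seq_len, logits = llava.prefill(prompt_before, image, prompt_after)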

import re

from typing import Any, Dict, Optional, Tuple

import requests
import torch
from executorch.examples.models.llama2.llama_transformer import ModelArgs, Transformer

from executorch.examples.models.llama2.source_transformation.sdpa import (
    replace_sdpa_with_custom_op,
)
from executorch.examples.models.llava.image_util import prepare_image
from executorch.examples.models.model_base import EagerModelBase
from PIL import Image

from torch.export import Dim
from torchvision.transforms.v2 import functional as F

from transformers import (
    AutoProcessor,
    CLIPImageProcessor,
    LlamaForCausalLM,
    LlavaForConditionalGeneration,
)


class Llava(torch.nn.Module):
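    """Eager wrapper around HF's LlavaForConditionalGeneration that swaps the
    language model for ExecuTorch's Llama ``Transformer`` (with KV cache and,
    optionally, the custom SDPA op) while reusing the HF vision tower,
    multimodal projector, and token embeddings."""
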
    def __init__(
        self,
        llava_model: LlavaForConditionalGeneration,
        image_processor: CLIPImageProcessor,
        use_sdpa_with_kv_cache_op: bool = True,
        max_seq_len: int = 768,
    ):
        super().__init__()
        self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op
        self.model_ = llava_model
        self.image_processor = image_processor
        self.vision_feature_layer = self.model_.config.vision_feature_layer
        self.vision_feature_select_strategy = (
            self.model_.config.vision_feature_select_strategy
        )
        self.text_model_args = ModelArgs(
            use_kv_cache=True,
            vocab_size=self.model_.config.text_config.vocab_size,
            hidden_dim=self.model_.config.text_config.intermediate_size,
            max_batch_size=1,  # doesn't work with the default batch size of 32
            ffn_dim_multiplier=1,  # TODO: a hack to make the rotary embedding happy
            enable_dynamic_shape=True,  # allow parallel prefill
            use_sdpa_with_kv_cache_op=use_sdpa_with_kv_cache_op,
            use_hf_rope=True,
            max_seq_len=max_seq_len,
        )
        self.text_model = Transformer(self.text_model_args)
        # Use the custom op for SDPA.
        if use_sdpa_with_kv_cache_op:
            self.text_model = replace_sdpa_with_custom_op(self.text_model)
        # Load the translated HF state dict into the ExecuTorch text model.
        self.text_model.load_state_dict(
            state_dict=self._translate_state_dict_for_text_model(),
            strict=False,
            assign=True,
        )

    def _translate_state_dict_for_text_model(self) -> Dict[str, Any]:
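        """Rename HF LlamaForCausalLM parameters to the key scheme expected by
        ExecuTorch's Llama ``Transformer`` (e.g. ``model.layers.N.self_attn.q_proj``
        becomes ``layers.N.attention.wq``). Token embeddings are intentionally
        left out; they are read from the HF model at runtime."""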
        state_dict = self.model_.language_model.state_dict()
        key_map = {
            # fmt: off
            r"model.layers.([0-9]+).self_attn.q_proj.": r"layers.\1.attention.wq.",
            r"model.layers.([0-9]+).self_attn.k_proj.": r"layers.\1.attention.wk.",
            r"model.layers.([0-9]+).self_attn.v_proj.": r"layers.\1.attention.wv.",
            r"model.layers.([0-9]+).self_attn.o_proj.": r"layers.\1.attention.wo.",
            r"model.layers.([0-9]+).input_layernorm.": r"layers.\1.attention_norm.",
            r"model.layers.([0-9]+).mlp.gate_proj.": r"layers.\1.feed_forward.w1.",
            r"model.layers.([0-9]+).mlp.down_proj.": r"layers.\1.feed_forward.w2.",
            r"model.layers.([0-9]+).mlp.up_proj.": r"layers.\1.feed_forward.w3.",
            r"model.layers.([0-9]+).post_attention_layernorm.": r"layers.\1.ffn_norm.",
            r"model.norm.": r"norm.",
            # r"model.embed_tokens.": r"tok_embeddings.",  # loaded separately
            r"lm_head.": r"output.",
            # fmt: on
        }

        new_state_dict = {}

        def get_new_key(old_key: str) -> str:
            for old_pattern, replacement in key_map.items():
                if (new_key := re.sub(old_pattern, replacement, old_key)) != old_key:
                    return new_key

            return old_key

        # Convert module keys from the HF transformer to the Llama transformer.
        for old_key in state_dict.keys():
            new_key = get_new_key(old_key)

            new_state_dict[new_key] = state_dict[old_key]

        return new_state_dict

    def _feature_select(self, image_outputs):
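        """Select image features from the configured vision tower layer. The
        "default" strategy drops the leading CLS token so only patch embeddings
        remain; "full" keeps every position."""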
        selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]

        if self.vision_feature_select_strategy == "default":
            selected_image_feature = selected_image_feature[:, 1:]
        elif self.vision_feature_select_strategy == "full":
            pass  # keep all positions, including the CLS token
        else:
            raise ValueError(
                f"Unexpected vision feature select strategy: {self.vision_feature_select_strategy}"
            )
        return selected_image_feature

    def get_model(self):
        return self.model_.get_model()

    def embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.model_.language_model.model.embed_tokens(tokens)

    def encode_images(self, images: torch.Tensor) -> torch.Tensor:
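        """Run images through the vision tower, select features, and project
        them into the text embedding space via the multimodal projector."""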
        if isinstance(images, list):
            image_features = []
            for image in images:
                image_forward_out = self.model_.vision_tower(
                    image.to(
                        device=self.model_.device, dtype=self.model_.dtype
                    ).unsqueeze(0),
                    output_hidden_states=True,
                )
                image_feature = self._feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            # Cast here rather than before the isinstance check: `.to()` is
            # only defined on tensors, not on lists.
            images = images.to(device=self.model_.device, dtype=self.model_.dtype)
            image_forward_outs = self.model_.vision_tower(
                images,
                output_hidden_states=True,
            )
            image_features = self._feature_select(image_forward_outs).to(images.dtype)
        image_features = self.model_.multi_modal_projector(image_features)
        return image_features

    def image_preprocess(self, img: torch.Tensor) -> torch.Tensor:
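        """Pad a (C, H, W) image up to the processor's square crop size, then
        rescale and normalize it, returning a (1, C, H', W') batch."""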
        target_h = self.image_processor.crop_size["height"]
        target_w = self.image_processor.crop_size["width"]
        # Pad the image up to the target size. The original implementation
        # fills with the per-channel image mean; zero fill is used here (see
        # the export note below).
        l_pad = (target_w - img.shape[2]) // 2
        t_pad = (target_h - img.shape[1]) // 2
        # Ceil division for the right/bottom pads so odd remainders are covered.
        r_pad = -((target_w - img.shape[2]) // -2)
        b_pad = -((target_h - img.shape[1]) // -2)

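        # Tell the exporter the pads are non-negative so its symbolic-shape
        # guards can be discharged.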
        torch._check(l_pad >= 0)
        torch._check(t_pad >= 0)
        torch._check(r_pad >= 0)
        torch._check(b_pad >= 0)

        # This differs from the original implementation due to export
        # limitations: the padding is done with zeros instead of the mean fill.
        resized = torch.nn.functional.pad(
            img,
            (l_pad, r_pad, t_pad, b_pad),
        )
        # Originally:
        # resized = F.pad(
        #     img,
        #     padding=(l_pad, t_pad, r_pad, b_pad),
        #     fill=tuple(int(x * 255) for x in self.image_processor.image_mean),
        # )

        # TODO: implement _upsample_bicubic_aa.out in the portable kernel
        # library. The padded shape should be max(h, w) x max(h, w); the
        # bicubic resize to the crop size is skipped for now because the
        # _upsample_bicubic_aa kernel is missing from the portable kernels:
        # resized = resize(
        #     padded,
        #     size=[
        #         self.image_processor.crop_size["height"],
        #         self.image_processor.crop_size["width"],
        #     ],
        #     interpolation="bicubic",
        # )
        # torch._check(resized.size(1) == self.image_processor.crop_size["height"])
        # torch._check(resized.size(2) == self.image_processor.crop_size["width"])

        # Rescale pixel values (rescale_factor is typically 1/255) and
        # normalize with the processor's per-channel mean and std.
        scaled = resized * self.image_processor.rescale_factor
        normed = F.normalize(
            scaled, self.image_processor.image_mean, self.image_processor.image_std
        )
        return normed.unsqueeze(0)

    def step(
        self, token: torch.Tensor, input_pos: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Decode one step: take a single token and its position, and return
        the logits for the next token."""
        token_embeds = self.embed_tokens(token).unsqueeze(0)
        return self.text_model.forward(None, input_pos, token_embeds)

    def image_embedding(self, images: torch.Tensor) -> torch.Tensor:
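        """Preprocess a raw image and encode it into text-space embeddings."""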
        preprocessed_img = self.image_preprocess(images)
        return self.encode_images(preprocessed_img)

    def prefill_embedding(
        self,
        prompt_before_image: torch.Tensor,
        images: torch.Tensor,
        prompt_after_image: torch.Tensor,
    ) -> torch.Tensor:
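        """Build the prefill sequence: token embeddings for the text before the
        image, the projected image embeddings, then token embeddings for the
        text after the image, concatenated along the sequence dimension."""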
        image_embeds = self.image_embedding(images)
        embeds_before_img = self.embed_tokens(prompt_before_image)
        embeds_after_img = self.embed_tokens(prompt_after_image)
        result = torch.cat((embeds_before_img, image_embeds, embeds_after_img), dim=1)
        return result

    # Prefill using the in-house text model of the Llama transformer.
    def prefill(
        self,
        prompt_before_image: torch.Tensor,
        images: torch.Tensor,
        prompt_after_image: torch.Tensor,
    ) -> Tuple[int, torch.Tensor]:
        """Avoid a torch.where() call to find the <image> placeholder and
        insert the image embedding; take three inputs instead."""
        embeds = self.prefill_embedding(prompt_before_image, images, prompt_after_image)
        # Also return the prefilled token length: the text model produces
        # logits for only one position per forward call, so the caller needs to
        # know where decoding resumes.
        return embeds.shape[1], self.text_model.forward(None, torch.tensor([0]), embeds)

    # Reference prefill using the text model from HF, for comparison.
    def prefill_ref(
        self,
        prompt_before_image: torch.Tensor,
        images: torch.Tensor,
        prompt_after_image: torch.Tensor,
    ) -> torch.Tensor:
        """Avoid a torch.where() call to find the <image> placeholder and
        insert the image embedding; take three inputs instead."""
        embeds = self.prefill_embedding(prompt_before_image, images, prompt_after_image)
        return LlamaForCausalLM.forward(
            self.model_.language_model,
            inputs_embeds=embeds,
            return_dict=False,
            use_cache=False,
            output_hidden_states=False,
        )

    def forward(
        self,
        images: torch.Tensor,
    ) -> torch.Tensor:
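        """The exported entry point covers only the image-embedding path; text
        prefill and decode go through prefill() and step()."""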
        return self.image_embedding(images)


class LlavaModel(EagerModelBase):
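    """EagerModelBase wrapper that loads the llava-hf/llava-1.5-7b-hf
    checkpoint, an example image, and an example prompt, and exposes the eager
    model, example inputs, and dynamic shapes needed for export."""
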
    def __init__(self, use_sdpa_with_kv_cache_op=True, max_seq_len=768):
        self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op
        self.max_seq_len = max_seq_len
        self.processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
        self.tokenizer = self.processor.tokenizer
        self.image_processor = self.processor.image_processor
        self.model = LlavaForConditionalGeneration.from_pretrained(
            "llava-hf/llava-1.5-7b-hf",
            device_map="cpu",
        )
        self.image = Image.open(
            requests.get(
                "https://llava-vl.github.io/static/images/view.jpg", stream=True
            ).raw
        )
        self.prompt = """A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>
What are the things I should be cautious about when I visit here? ASSISTANT:"""
        self.model_name = "llava-1.5-7b-hf"
        # Set the example inputs to None; they are initialized lazily on first use.
        self.input = None
        self.resized_image = None

    def get_eager_model(self):
        model = Llava(
            self.model,
            self.image_processor,
            self.use_sdpa_with_kv_cache_op,
            self.max_seq_len,
        )
        model.to(dtype=torch.float32)
        return model

    def get_example_inputs(self):
        """Returns a resized image as input to model.forward()."""
        if self.resized_image:
            return self.resized_image
        resized = prepare_image(
            self.image,
            self.image_processor.crop_size["height"],
            self.image_processor.crop_size["width"],
        )
        self.resized_image = (resized,)
        return self.resized_image

    def get_inputs_for_prefill(self):
        """Returns the prompt token ids split around the <image> placeholder,
        with the example image in between."""
        if self.input:
            return self.input
        self.input_ids = self.tokenizer.encode(self.prompt, return_tensors="pt").cpu()
        # Split the prompt at the <image> placeholder token.
        index = torch.where(self.input_ids == self.model.config.image_token_index)[1]
        self.prompt_before_image = self.input_ids[:, :index]
        self.prompt_after_image = self.input_ids[:, index + 1 :]
        self.input = (
            self.prompt_before_image,
            *self.get_example_inputs(),
            self.prompt_after_image,
        )
        return self.input

    def get_dynamic_shapes(self):
        return self._get_image_dynamic_shapes()

    def _get_image_dynamic_shapes(self):
        # Only even heights and widths are supported for now; deriving each dim
        # as 2 * Dim(...) encodes that evenness constraint for the exporter.
        _height = Dim(
            "_height", min=1, max=self.image_processor.crop_size["height"] // 2
        )
        _width = Dim("_width", min=1, max=self.image_processor.crop_size["width"] // 2)
        height = 2 * _height
        width = 2 * _width
        dynamic_shapes = [{1: height, 2: width}]
        return dynamic_shapes

    def _get_prompt_dynamic_shapes(self):
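        """Dynamic-shape spec for exporting the text model: the first input
        keeps a static leading dimension of 1, while the second input's
        sequence dimension (dim 1) may range up to max_seq_len."""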
        dim = torch.export.Dim("token_dim", min=2, max=self.max_seq_len)
        text_model_dynamic_shapes = ({0: 1}, {1: dim})
        return text_model_dynamic_shapes