# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import re
from typing import Any, Dict, Mapping


def convert_model_state_dict(
    state_dict: Dict[str, Any], key_map: Mapping[str, str]
) -> Dict[str, Any]:
"""Convert a model state dictionary to fairseq2.
:param state_dict:
The original model state dictionary.
:param key_map:
A map of regex patterns to fairseq2 model keys.
:returns:
A converted model state dictionary that is compatible with fairseq2.
"""
    new_state_dict = {}

    def get_new_key(old_key: str) -> str:
        # Apply the first pattern that changes the key; fall back to the
        # original key if no pattern matches.
        for old_pattern, replacement in key_map.items():
            if (new_key := re.sub(old_pattern, replacement, old_key)) != old_key:
                return new_key

        return old_key

    # Convert the module keys using the key map.
    for old_key, value in state_dict.items():
        new_state_dict[get_new_key(old_key)] = value

    return new_state_dict
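
# A minimal doctest-style sketch of how the key map is applied (hypothetical
# scalar values stand in for tensors; the pattern is the q_proj entry from the
# LLaMA key map below):
#
#   >>> convert_model_state_dict(
#   ...     {"decoder.layers.0.self_attn.q_proj.weight": 0},
#   ...     {r"^decoder\.layers\.([0-9]+)\.self_attn\.q_proj\.": r"layers.\1.attention.wq."},
#   ... )
#   {'layers.0.attention.wq.weight': 0}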


def convert_to_llama_checkpoint(checkpoint: Dict[str, Any]) -> Dict[str, Any]:
    """Convert a fairseq2 LLaMA checkpoint to the reference format."""
    key_map = {
        # fmt: off
        r"^decoder\.layers\.([0-9]+)\.self_attn\.q_proj\.":       r"layers.\1.attention.wq.",
        r"^decoder\.layers\.([0-9]+)\.self_attn\.k_proj\.":       r"layers.\1.attention.wk.",
        r"^decoder\.layers\.([0-9]+)\.self_attn\.v_proj\.":       r"layers.\1.attention.wv.",
        r"^decoder\.layers\.([0-9]+)\.self_attn\.output_proj\.":  r"layers.\1.attention.wo.",
        r"^decoder\.layers\.([0-9]+)\.self_attn_layer_norm\.":    r"layers.\1.attention_norm.",
        r"^decoder\.layers\.([0-9]+)\.ffn\.gate_proj\.":          r"layers.\1.feed_forward.w1.",
        r"^decoder\.layers\.([0-9]+)\.ffn\.output_proj\.":        r"layers.\1.feed_forward.w2.",
        r"^decoder\.layers\.([0-9]+)\.ffn\.inner_proj\.":         r"layers.\1.feed_forward.w3.",
        r"^decoder\.layers\.([0-9]+)\.ffn_layer_norm\.":          r"layers.\1.ffn_norm.",
        r"^decoder\.layer_norm\.":                                r"norm.",
        r"^decoder_frontend\.embed\.":                            r"tok_embeddings.",
        r"^final_proj\.":                                         r"output.",
        # fmt: on
    }

    return convert_model_state_dict(checkpoint, key_map)
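

# A quick usage sketch (hypothetical scalar values; a real checkpoint maps
# parameter names to tensors, typically loaded with torch.load and written
# back with torch.save):
#
#   >>> convert_to_llama_checkpoint({"decoder_frontend.embed.weight": 0})
#   {'tok_embeddings.weight': 0}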