Sync from v0.13
This commit is contained in:
154
vllm/model_executor/models/fairseq2_llama.py
Normal file
154
vllm/model_executor/models/fairseq2_llama.py
Normal file
@@ -0,0 +1,154 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright 2024 The vLLM team.
|
||||
# Copyright 2024 Meta Platforms, Inc. and affiliates. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Llama model for fairseq2 weights."""
|
||||
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.model_executor.layers.linear import set_weight_attrs
|
||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||
|
||||
from .utils import AutoWeightsLoader, WeightsMapper
|
||||
|
||||
|
||||
class Fairseq2LlamaForCausalLM(LlamaForCausalLM):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
# For the model loader to read only the relevant checkpoint files
|
||||
self.allow_patterns_overrides = [
|
||||
# either the full checkpoint
|
||||
"model.pt",
|
||||
# or the tp-sharded checkpoint of the current rank
|
||||
f"model.{self.tp_rank}.pt",
|
||||
]
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
# fairseq2's serialization adds a wrapper to usual .pt state_dict's:
|
||||
# { "model_key": my_model_name, "my_model_name": state_dict }
|
||||
# which we first need to unpack
|
||||
weights_wrapped = dict(weights)
|
||||
weights = weights_wrapped[weights_wrapped["model_key"]].items() # type: ignore
|
||||
|
||||
# remap keys
|
||||
fs2_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
"decoder_frontend.embed.": "model.embed_tokens.",
|
||||
"decoder.": "model.",
|
||||
"final_proj.": "lm_head.",
|
||||
},
|
||||
orig_to_new_substr={
|
||||
".self_attn_layer_norm.": ".input_layernorm.",
|
||||
".ffn_layer_norm.": ".post_attention_layernorm.",
|
||||
".self_attn.output_proj.": ".self_attn.o_proj.",
|
||||
".ffn.gate_proj.": ".mlp.gate_proj.",
|
||||
".ffn.inner_proj.": ".mlp.up_proj.",
|
||||
".ffn.output_proj.": ".mlp.down_proj.",
|
||||
".layer_norm.": ".norm.",
|
||||
},
|
||||
)
|
||||
weights = fs2_to_vllm_mapper.apply(weights)
|
||||
|
||||
params = dict(self.named_parameters())
|
||||
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
|
||||
)
|
||||
return loader.load_weights(
|
||||
(
|
||||
self.reshape_fairseq2_weights(name, loaded_weight, params)
|
||||
for name, loaded_weight in weights
|
||||
)
|
||||
)
|
||||
|
||||
def flag_sharded_weights(self, params: dict[str, Parameter]):
|
||||
"""Sets the `is_sharded_weight` flag to True for all sharded weights"""
|
||||
for name, param in params.items():
|
||||
modules = name.split(".")
|
||||
if "norm" in name and len(param.size()) < 2:
|
||||
# layer norms are not sharded
|
||||
continue
|
||||
elif any(emb in modules for emb in ["embed_tokens", "lm_head"]):
|
||||
# for now we repeat embedding layers for compatibility
|
||||
continue
|
||||
else:
|
||||
# all other layers are sharded
|
||||
set_weight_attrs(param, {"is_sharded_weight": True})
|
||||
|
||||
def reshape_fairseq2_weights(
|
||||
self,
|
||||
name: str,
|
||||
loaded_weight: torch.Tensor,
|
||||
params: dict[str, Parameter],
|
||||
) -> tuple[str, torch.Tensor]:
|
||||
"""Reshape fairseq2's weights."""
|
||||
|
||||
def permute(w: torch.Tensor, n_heads: int) -> torch.Tensor:
|
||||
attn_in = self.config.head_dim * n_heads
|
||||
# check for a sharded weight on dim 0
|
||||
if attn_in // self.tp_size == w.size()[0]:
|
||||
attn_in //= self.tp_size
|
||||
n_heads //= self.tp_size
|
||||
attn_out = self.config.hidden_size
|
||||
return (
|
||||
w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
|
||||
.transpose(1, 2)
|
||||
.reshape(attn_in, attn_out)
|
||||
)
|
||||
|
||||
modules = name.split(".")
|
||||
|
||||
# rotary embeds should be sliced
|
||||
if "k_proj" in modules:
|
||||
loaded_weight = permute(loaded_weight, self.config.num_key_value_heads)
|
||||
|
||||
elif "q_proj" in modules:
|
||||
loaded_weight = permute(loaded_weight, self.config.num_attention_heads)
|
||||
|
||||
# We make the loaded weights compatible with both
|
||||
# full checkpoints and tp sharded checkpoints.
|
||||
# Embeddings are repeated to fit the vocab size.
|
||||
# Other weights are flagged for the weight_loader calls.
|
||||
if any(emb in modules for emb in ["embed_tokens", "lm_head"]):
|
||||
# Embeddings are sharded on dim 0
|
||||
dim = 0
|
||||
# In fairseq2, vocab size has to be divisible by tp_size
|
||||
# so we don't worry about padding
|
||||
if self.tp_size > 1 and loaded_weight.shape[dim] < self.config.vocab_size:
|
||||
assert (
|
||||
loaded_weight.shape[dim] * self.tp_size == self.config.vocab_size
|
||||
), "vocab_size should be divisible by tp_size."
|
||||
repeats = [1] * len(loaded_weight.size())
|
||||
repeats[dim] = self.tp_size
|
||||
# repeat to match vocab size and to be easily 'narrow'able
|
||||
loaded_weight = loaded_weight.repeat(repeats)
|
||||
set_weight_attrs(params[name], {"is_sharded_weight": False})
|
||||
# if embeddings are sharded, the rest is too
|
||||
if "embed_tokens" in modules:
|
||||
self.flag_sharded_weights(params)
|
||||
|
||||
return name, loaded_weight
|
||||
Reference in New Issue
Block a user