Add get weights by parameter name for llama (#2266)

This commit is contained in:
Chayenne
2024-11-29 23:36:38 -08:00
committed by GitHub
parent 7d5d1d3d29
commit 7d1485d376
12 changed files with 337 additions and 17 deletions

View File

@@ -16,6 +16,7 @@
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
"""Inference-only LLaMA model compatible with HuggingFace weights."""
import logging
from typing import Any, Dict, Iterable, Optional, Tuple
import torch
@@ -45,6 +46,8 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import make_layers
logger = logging.getLogger(__name__)
class LlamaMLP(nn.Module):
def __init__(
@@ -305,6 +308,14 @@ class LlamaForCausalLM(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
self.stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
@torch.no_grad()
def forward(
@@ -349,15 +360,7 @@ class LlamaForCausalLM(nn.Module):
return params_mapping.get(name, name)
def get_module_name_from_weight_name(self, name):
stacked_params_mapping = [
# (param_name, shard_name, shard_id, num_shard)
("qkv_proj", "q_proj", "q", 3),
("qkv_proj", "k_proj", "k", 3),
("qkv_proj", "v_proj", "v", 3),
("gate_up_proj", "gate_proj", 0, 2),
("gate_up_proj", "up_proj", 1, 2),
]
for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping:
if weight_name in name:
return (
name.replace(weight_name, param_name)[: -len(".weight")],
@@ -370,6 +373,7 @@ class LlamaForCausalLM(nn.Module):
return len(params_dict)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
embed_tokens_weight = None
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
@@ -378,6 +382,7 @@ class LlamaForCausalLM(nn.Module):
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
load_tie_word_embeddings = (
@@ -425,10 +430,79 @@ class LlamaForCausalLM(nn.Module):
# Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
param = self.lm_head.weight
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, embed_tokens_weight)
if embed_tokens_weight is not None:
weight_loader(param, embed_tokens_weight)
apply_torchao_config_(self, params_dict, set(["proj.weight"]))
def get_weights_by_name(
self, name: str, truncate_size: int = 100, tp_size: int = 1
) -> Optional[torch.Tensor]:
"""Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.
Only used for unit test with an unoptimized performance.
For optimized performance, please use torch.save and torch.load.
"""
try:
mapped_name = name
mapped_shard_id = None
for param_name, weight_name, shard_id in self.stacked_params_mapping:
if weight_name in name:
mapped_name = name.replace(weight_name, param_name)
mapped_shard_id = shard_id
break
params_dict = dict(self.named_parameters())
if mapped_name in params_dict:
param = params_dict[mapped_name]
if mapped_shard_id is not None:
if mapped_shard_id in ["q", "k", "v"]:
num_heads = self.config.num_attention_heads // tp_size
num_kv_heads = self.config.num_key_value_heads // tp_size
head_dim = (
self.config.hidden_size // self.config.num_attention_heads
)
if mapped_shard_id == "q":
offset = 0
size = num_heads * head_dim
elif mapped_shard_id == "k":
offset = num_heads * head_dim
size = num_kv_heads * head_dim
elif mapped_shard_id == "v":
offset = (num_heads + num_kv_heads) * head_dim
size = num_kv_heads * head_dim
weight = param.data.narrow(0, offset, size)
elif mapped_shard_id in [0, 1]:
intermediate_size = self.config.intermediate_size
hidden_size = self.config.hidden_size
slice_size = intermediate_size // tp_size
if mapped_shard_id == 0: # gate_proj
offset = 0
size = slice_size
elif mapped_shard_id == 1: # up_proj
offset = slice_size
size = slice_size
weight = param.data.narrow(0, offset, size)
else:
weight = param.data
else:
weight = param.data
if tp_size > 1 and ("o_proj" in name or "down_proj" in name):
gathered_weights = [
torch.zeros_like(weight) for _ in range(tp_size)
]
torch.distributed.all_gather(gathered_weights, weight)
weight = torch.cat(gathered_weights, dim=1)
return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size]
else:
return None
except Exception as e:
logger.error(
f"Error getting weights by name {name} in LlamaForCausalLM: {e}"
)
return None
class Phi3ForCausalLM(LlamaForCausalLM):
pass