Add get weights by parameter name for llama (#2266)
@@ -16,6 +16,7 @@
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
 """Inference-only LLaMA model compatible with HuggingFace weights."""
 
+import logging
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
@@ -45,6 +46,8 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import make_layers
 
+logger = logging.getLogger(__name__)
+
 
 class LlamaMLP(nn.Module):
     def __init__(
@@ -305,6 +308,14 @@ class LlamaForCausalLM(nn.Module):
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        self.stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
 
     @torch.no_grad()
     def forward(
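The hunk above lifts the checkpoint-name table into a shared attribute on the model. As a rough illustration of the rewrite it encodes (a standalone sketch; `fuse_weight_name` is a hypothetical helper, not part of this commit), a HuggingFace tensor name is matched by substring and redirected to the fused parameter:

    # Standalone sketch; `fuse_weight_name` is hypothetical, not part of the commit.
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        (".qkv_proj", ".q_proj", "q"),
        (".qkv_proj", ".k_proj", "k"),
        (".qkv_proj", ".v_proj", "v"),
        (".gate_up_proj", ".gate_proj", 0),
        (".gate_up_proj", ".up_proj", 1),
    ]

    def fuse_weight_name(name):
        # Return the fused parameter name plus the shard id the
        # original tensor occupies inside it.
        for param_name, shard_name, shard_id in stacked_params_mapping:
            if shard_name in name:
                return name.replace(shard_name, param_name), shard_id
        return name, None

    # ("model.layers.0.self_attn.qkv_proj.weight", "q")
    print(fuse_weight_name("model.layers.0.self_attn.q_proj.weight"))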
@@ -349,15 +360,7 @@ class LlamaForCausalLM(nn.Module):
         return params_mapping.get(name, name)
 
     def get_module_name_from_weight_name(self, name):
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id, num_shard)
-            ("qkv_proj", "q_proj", "q", 3),
-            ("qkv_proj", "k_proj", "k", 3),
-            ("qkv_proj", "v_proj", "v", 3),
-            ("gate_up_proj", "gate_proj", 0, 2),
-            ("gate_up_proj", "up_proj", 1, 2),
-        ]
-        for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
+        for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping:
             if weight_name in name:
                 return (
                     name.replace(weight_name, param_name)[: -len(".weight")],
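With the local table gone, get_module_name_from_weight_name reads the shared mapping instead. Assuming the fall-through case returns the stripped name with a shard count of 1 (the tail of the method sits outside this hunk), its resolution can be sketched as:

    # Hedged sketch of what get_module_name_from_weight_name computes; the
    # num_shard values come from the table removed above, and the fall-through
    # return is an assumption, since the method's tail is outside this hunk.
    MAPPING = [
        # (param_name, shard_name, shard_id, num_shard)
        ("qkv_proj", "q_proj", "q", 3),
        ("qkv_proj", "k_proj", "k", 3),
        ("qkv_proj", "v_proj", "v", 3),
        ("gate_up_proj", "gate_proj", 0, 2),
        ("gate_up_proj", "up_proj", 1, 2),
    ]

    def module_name_from_weight_name(name):
        for param_name, weight_name, _shard_id, num_shard in MAPPING:
            if weight_name in name:
                return name.replace(weight_name, param_name)[: -len(".weight")], num_shard
        return name[: -len(".weight")], 1

    # ("model.layers.0.self_attn.qkv_proj", 3)
    print(module_name_from_weight_name("model.layers.0.self_attn.k_proj.weight"))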
@@ -370,6 +373,7 @@ class LlamaForCausalLM(nn.Module):
         return len(params_dict)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        embed_tokens_weight = None
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".qkv_proj", ".q_proj", "q"),
@@ -378,6 +382,7 @@ class LlamaForCausalLM(nn.Module):
             (".gate_up_proj", ".gate_proj", 0),
             (".gate_up_proj", ".up_proj", 1),
         ]
+
         params_dict = dict(self.named_parameters())
 
         load_tie_word_embeddings = (
@@ -425,10 +430,79 @@ class LlamaForCausalLM(nn.Module):
             # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
             param = self.lm_head.weight
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
-            weight_loader(param, embed_tokens_weight)
+            if embed_tokens_weight is not None:
+                weight_loader(param, embed_tokens_weight)
 
         apply_torchao_config_(self, params_dict, set(["proj.weight"]))
 
+    def get_weights_by_name(
+        self, name: str, truncate_size: int = 100, tp_size: int = 1
+    ) -> Optional[torch.Tensor]:
+        """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.
+
+        Only used for unit tests with unoptimized performance.
+        For optimized performance, please use torch.save and torch.load.
+        """
+        try:
+            mapped_name = name
+            mapped_shard_id = None
+            for param_name, weight_name, shard_id in self.stacked_params_mapping:
+                if weight_name in name:
+                    mapped_name = name.replace(weight_name, param_name)
+                    mapped_shard_id = shard_id
+                    break
+            params_dict = dict(self.named_parameters())
+            if mapped_name in params_dict:
+                param = params_dict[mapped_name]
+                if mapped_shard_id is not None:
+                    if mapped_shard_id in ["q", "k", "v"]:
+                        num_heads = self.config.num_attention_heads // tp_size
+                        num_kv_heads = self.config.num_key_value_heads // tp_size
+                        head_dim = (
+                            self.config.hidden_size // self.config.num_attention_heads
+                        )
+                        if mapped_shard_id == "q":
+                            offset = 0
+                            size = num_heads * head_dim
+                        elif mapped_shard_id == "k":
+                            offset = num_heads * head_dim
+                            size = num_kv_heads * head_dim
+                        elif mapped_shard_id == "v":
+                            offset = (num_heads + num_kv_heads) * head_dim
+                            size = num_kv_heads * head_dim
+                        weight = param.data.narrow(0, offset, size)
+                    elif mapped_shard_id in [0, 1]:
+                        intermediate_size = self.config.intermediate_size
+                        hidden_size = self.config.hidden_size
+                        slice_size = intermediate_size // tp_size
+                        if mapped_shard_id == 0:  # gate_proj
+                            offset = 0
+                            size = slice_size
+                        elif mapped_shard_id == 1:  # up_proj
+                            offset = slice_size
+                            size = slice_size
+
+                        weight = param.data.narrow(0, offset, size)
+                    else:
+                        weight = param.data
+                else:
+                    weight = param.data
+                if tp_size > 1 and ("o_proj" in name or "down_proj" in name):
+                    gathered_weights = [
+                        torch.zeros_like(weight) for _ in range(tp_size)
+                    ]
+                    torch.distributed.all_gather(gathered_weights, weight)
+                    weight = torch.cat(gathered_weights, dim=1)
+                return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size]
+            else:
+                return None
+
+        except Exception as e:
+            logger.error(
+                f"Error getting weights by name {name} in LlamaForCausalLM: {e}"
+            )
+            return None
+
 
 class Phi3ForCausalLM(LlamaForCausalLM):
     pass
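The q/k/v branch of get_weights_by_name recovers one logical projection from the fused qkv_proj weight, whose rows stack Q, then K, then V along dim 0. A quick numeric check of those offsets, using illustrative Llama-3-8B-style values (32 attention heads, 8 KV heads, hidden size 4096; these numbers are assumptions, not read from the commit):

    # Illustrative check of the narrow() offsets in the q/k/v branch.
    num_attention_heads, num_key_value_heads, hidden_size, tp_size = 32, 8, 4096, 1

    num_heads = num_attention_heads // tp_size
    num_kv_heads = num_key_value_heads // tp_size
    head_dim = hidden_size // num_attention_heads  # 128

    # Fused qkv_proj.weight stacks rows as [Q | K | V] along dim 0.
    q_offset, q_size = 0, num_heads * head_dim                         # 0, 4096
    k_offset, k_size = num_heads * head_dim, num_kv_heads * head_dim   # 4096, 1024
    v_offset, v_size = (num_heads + num_kv_heads) * head_dim, num_kv_heads * head_dim  # 5120, 1024

    # The three slices tile the fused output dimension exactly.
    assert v_offset + v_size == (num_heads + 2 * num_kv_heads) * head_dim  # 6144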
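For o_proj and down_proj, each tensor-parallel rank holds a slice of the input (second) dimension, so the method all-gathers the per-rank shards and concatenates them on dim=1. A single-process stand-in for that reassembly (shapes are illustrative):

    import torch

    # Single-process stand-in for the all_gather + cat in get_weights_by_name:
    # row-parallel weights shard dim 1, so the full matrix is the dim=1 concat
    # of every rank's shard.
    tp_size = 4
    full = torch.randn(1024, 4096)
    shards = list(full.chunk(tp_size, dim=1))  # what each rank would hold

    gathered = [torch.zeros_like(shards[0]) for _ in range(tp_size)]
    for rank, shard in enumerate(shards):      # plays the role of all_gather
        gathered[rank].copy_(shard)

    assert torch.equal(torch.cat(gathered, dim=1), full)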
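As the docstring says, the method targets unit tests. A sketch of how such a test might compare the served weights against PyTorch's nn.Module.get_parameter on a HuggingFace model; the model handles and the comparison tolerance are assumptions, not part of this commit:

    import torch

    def check_weights_match(sgl_model, hf_model, name, truncate_size=100):
        # `sgl_model` is an initialized LlamaForCausalLM from this file and
        # `hf_model` a transformers model for the same checkpoint (assumed setup).
        got = sgl_model.get_weights_by_name(name, truncate_size)  # nested fp32 lists
        want = hf_model.get_parameter(name).to(torch.float32).tolist()[:truncate_size]
        assert got is not None
        # Loose tolerance: weights may be cast during serving-side loading.
        assert torch.allclose(torch.tensor(got), torch.tensor(want), atol=1e-3)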