################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

# SPDX-License-Identifier: Apache-2.0

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
|
from typing import Any, Dict, Iterable, Optional, Tuple, Union
|
|
|
|
import torch
|
|
import torch_br
|
|
from fastcore.basics import patch_to
|
|
from transformers import LlamaConfig
|
|
|
|
import vllm.model_executor.models.llama
|
|
from vllm.attention import Attention, AttentionType
|
|
from vllm.config import CacheConfig, VllmConfig
|
|
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
|
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
|
RowParallelLinear)
|
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
|
from vllm.model_executor.model_loader.weight_utils import (
|
|
default_weight_loader, maybe_remap_kv_scale_name)
|
|
from vllm.model_executor.models.llama import (LlamaAttention,
|
|
LlamaDecoderLayer,
|
|
LlamaForCausalLM, LlamaModel)
|
|
from vllm.model_executor.models.utils import (extract_layer_index,
|
|
is_pp_missing_parameter)
|
|
from vllm.sequence import IntermediateTensors
|
|
from vllm_br import envs
|
|
from ..layers.quantization.compressed_tensors.utils import (
|
|
get_compressed_tensors_cache_scale)
|
|
from .supa_module import AttentionSplit, MergedGateUpMLPSiluL2
|
|
|
|
|
|
def LlamaDecoderLayer__init__(self,
                              vllm_config: VllmConfig,
                              prefix: str = "",
                              config: Optional[LlamaConfig] = None) -> None:
    """Replacement ``__init__`` for vLLM's ``LlamaDecoderLayer``.

    Builds the attention and MLP submodules with Biren (BR)-specific
    implementations: ``AttentionSplit`` when the per-device KV-weight
    granularity is large enough (or on br166, which requires it) and a
    fused ``MergedGateUpMLPSiluL2`` MLP. Installed via monkey-patch at the
    bottom of this module, so the signature must match upstream vLLM.

    Args:
        vllm_config: Global vLLM config; supplies model/cache/quant configs.
        prefix: Dotted module prefix used to name submodules for weight
            loading (e.g. ``model.layers.0``).
        config: Optional explicit ``LlamaConfig``; falls back to
            ``vllm_config.model_config.hf_config``.
    """
    # Bypass LlamaDecoderLayer.__init__ (which we are replacing) and run
    # nn.Module's initializer directly.
    super(LlamaDecoderLayer, self).__init__()
    config = config or vllm_config.model_config.hf_config
    cache_config = vllm_config.cache_config
    quant_config = vllm_config.quant_config
    self.hidden_size = config.hidden_size
    rope_theta = getattr(config, "rope_theta", 10000)
    rope_scaling = getattr(config, "rope_scaling", None)
    if rope_scaling is not None and getattr(
            config, "original_max_position_embeddings", None):
        # Propagate the pre-extension context length so RoPE scaling
        # (e.g. YaRN-style) can compute the scaling factor.
        rope_scaling["original_max_position_embeddings"] = (
            config.original_max_position_embeddings)
    max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
    # Support abacusai/Smaug-72B-v0.1 with attention_bias
    # Support internlm/internlm-7b with bias
    attention_bias = getattr(config, "attention_bias", False) or getattr(
        config, "bias", False)

    tp_size = get_tensor_model_parallel_world_size()
    # Number of compute units on the BR device; used to decide whether the
    # per-rank KV projection is wide enough for the split-attention kernel.
    spc_num = torch_br.supa.get_device_properties("supa").max_compute_units
    # determine whether use qkv merge weights
    min_w_gran = 32  # minimum weight granularity per compute unit
    is_166 = envs.VLLM_BR_DEVICE_SPC_NUM > 16
    # NOTE: current br166 don't support s(2)b split, so br166 can only use AttentionSplit
    if is_166 or (config.num_key_value_heads *
                  (self.hidden_size // config.num_attention_heads)
                  >= tp_size * spc_num * min_w_gran):
        # BR-optimized attention with split QKV weights.
        self.self_attn = AttentionSplit(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
    else:
        # Fall back to stock vLLM attention (with the patched __init__ /
        # forward defined below in this module).
        self.self_attn = LlamaAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )

    # Fused gate/up projection + SiLU MLP (BR-specific kernel).
    self.mlp = MergedGateUpMLPSiluL2(
        hidden_size=self.hidden_size,
        intermediate_size=config.intermediate_size,
        hidden_act=config.hidden_act,
        quant_config=quant_config,
        bias=getattr(config, "mlp_bias", False),
        prefix=f"{prefix}.mlp",
    )
    self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                            eps=config.rms_norm_eps)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Load checkpoint weights into the model (patched onto LlamaForCausalLM).

    Maps HuggingFace parameter names onto vLLM's stacked parameters
    (q/k/v -> qkv_proj, gate/up -> gate_up_proj), skips rotary-embedding
    caches, handles compressed-tensors KV-cache scales, and applies
    BR-specific post-processing after each load.

    Args:
        weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

    Returns:
        The set of parameter names that were actually loaded.
    """
    loaded_params = []
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    params_dict = dict(self.named_parameters())
    # determine whether is qkv merge weights
    # (AttentionSplit keeps q/k/v as separate parameters, in which case the
    # qkv stacking entries must not be applied.)
    qkv_merge = False
    for key in params_dict:
        if "qkv_proj" in key:
            qkv_merge = True
            break
    if not qkv_merge and len(stacked_params_mapping) >= 3:
        # Drop the three qkv_proj entries; keep the gate_up_proj ones.
        stacked_params_mapping = stacked_params_mapping[3:]

    for name, loaded_weight in weights:
        if "rotary_emb.inv_freq" in name:
            continue
        if ("rotary_emb.cos_cached" in name
                or "rotary_emb.sin_cached" in name):
            # Models trained using ColossalAI may include these tensors in
            # the checkpoint. Skip them.
            continue
        if scale_name := get_compressed_tensors_cache_scale(name):
            # Loading kv cache scales for compressed-tensors quantization
            param = params_dict[scale_name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            loaded_weight = loaded_weight[0]
            weight_loader(param, loaded_weight)
            loaded_params.append(scale_name)
            continue
        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue

            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            # weight layout infer
            # NOTE(review): `+ 0` forces a fresh tensor; presumably this
            # triggers the BR backend's weight-layout inference — confirm.
            param.data = param.data + 0
            loaded_params.append(name)
            break
        else:
            # Not a stacked parameter: load directly under its own name.
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # Remapping the name of FP8 kv-scale.
            name = maybe_remap_kv_scale_name(name, params_dict)
            if name is None:
                continue

            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            # weight layout infer
            param.data = param.data + 0
            if name.find("norm.weight") != -1:
                # Keep RMSNorm weights in fp32 for numerical stability.
                param.data = param.data.to(torch.float32)
            loaded_params.append(name)

    return set(loaded_params)
def llamamodel_forward(
    self,
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor,
                                                    list[torch.Tensor]]]:
    """Replacement forward for LlamaModel with a BR 3-D activation layout.

    Hidden states (and residuals) are given a leading size-1 dim via
    ``unsqueeze(0)`` before the decoder layers and squeezed back on return,
    to match the layout the BR kernels expect.

    Args:
        input_ids: Token ids; may be None when ``inputs_embeds`` is given.
        positions: Position ids for rotary embeddings.
        intermediate_tensors: Hidden state/residual handed over from the
            previous pipeline-parallel rank (required on non-first ranks).
        inputs_embeds: Precomputed input embeddings, if any.

    Returns:
        Final hidden states, an ``IntermediateTensors`` on non-last PP
        ranks, or ``(hidden_states, aux_hidden_states)`` when auxiliary
        hidden states were collected.
    """
    if get_pp_group().is_first_rank:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None
        # BR kernels expect a leading (dummy) batch dimension.
        hidden_states = hidden_states.unsqueeze(0)
    else:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]
        hidden_states = hidden_states.unsqueeze(0)
        residual = residual.unsqueeze(0)

    aux_hidden_states = []
    for idx, layer in enumerate(self.layers[self.start_layer:self.end_layer]):
        if idx in self.aux_hidden_state_layers:
            # BUGFIX: before the first layer on the first PP rank,
            # `residual` is None and `hidden_states + residual` would raise
            # TypeError. Guard like upstream vLLM does.
            aux_hidden_states.append(
                hidden_states +
                residual if residual is not None else hidden_states)

        hidden_states, residual = layer(positions, hidden_states, residual)

    if not get_pp_group().is_last_rank:
        # Hand off to the next pipeline rank; drop the dummy batch dim.
        return IntermediateTensors({
            "hidden_states":
            hidden_states.squeeze(0)
            if hidden_states is not None else hidden_states,
            "residual":
            residual.squeeze(0) if residual is not None else residual
        })

    hidden_states, _ = self.norm(hidden_states, residual)

    if len(aux_hidden_states) > 0:
        # NOTE(review): this path returns hidden_states WITHOUT squeeze(0),
        # unlike the path below — verify the aux-hidden-state consumer
        # expects the 3-D layout.
        return hidden_states, aux_hidden_states
    return hidden_states.squeeze(0)
def LlamaAttention_forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """Replacement forward for LlamaAttention (patched in below).

    Runs the fused QKV projection, splits it into q/k/v (using the BR
    split kernel on br166 devices), applies rotary embeddings and paged
    attention, then the output projection.

    Args:
        positions: Position ids for the rotary embedding.
        hidden_states: Input activations.

    Returns:
        The attention output after ``o_proj``.
    """
    fused_qkv, _ = self.qkv_proj(hidden_states)
    section_sizes = [self.q_size, self.kv_size, self.kv_size]
    if envs.VLLM_BR_DEVICE_SPC_NUM > 16:
        # br166: use the BR-specific split kernel.
        q, k, v = torch_br.split_w_sbp_infer(fused_qkv, section_sizes)
    else:
        q, k, v = fused_qkv.split(section_sizes, dim=-1)
    q, k = self.rotary_emb(positions, q, k)
    attn_out = self.attn(q, k, v)
    final_out, _ = self.o_proj(attn_out)
    return final_out
@patch_to(LlamaAttention)
def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
        attn_type: str = AttentionType.DECODER,
        prefix: str = "",
        dual_chunk_attention_config: Optional[dict[str, Any]] = None) -> None:
    """Patched ``LlamaAttention.__init__`` (applied via ``@patch_to``).

    Mirrors upstream vLLM's constructor, except that the QKV projection is
    only built with the quantization config when the config reports that
    QKV is quantized (``quant_config.qkv_quantized``); otherwise qkv_proj
    stays unquantized while o_proj keeps the full quant config.
    """
    # Bypass the original __init__ being replaced; run nn.Module's init.
    super(LlamaAttention, self).__init__()
    self.hidden_size = hidden_size
    tp_size = get_tensor_model_parallel_world_size()
    self.total_num_heads = num_heads
    assert self.total_num_heads % tp_size == 0
    self.num_heads = self.total_num_heads // tp_size
    self.total_num_kv_heads = num_kv_heads
    if self.total_num_kv_heads >= tp_size:
        # Number of KV heads is greater than TP size, so we partition
        # the KV heads across multiple tensor parallel GPUs.
        assert self.total_num_kv_heads % tp_size == 0
    else:
        # Number of KV heads is less than TP size, so we replicate
        # the KV heads across multiple tensor parallel GPUs.
        assert tp_size % self.total_num_kv_heads == 0
    self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
    # MistralConfig has an optional head_dim introduced by Mistral-Nemo
    self.head_dim = getattr(config, "head_dim",
                            self.hidden_size // self.total_num_heads)
    self.q_size = self.num_heads * self.head_dim
    self.kv_size = self.num_kv_heads * self.head_dim
    self.scaling = self.head_dim**-0.5
    self.rope_theta = rope_theta
    self.max_position_embeddings = max_position_embeddings
    # Only quantize qkv_proj when the quant config marks QKV as quantized.
    # NOTE(review): `qkv_quantized` appears to be a BR-specific extension
    # of QuantizationConfig — confirm against the quantization layer.
    qconfig = None
    if quant_config is not None and quant_config.qkv_quantized:
        qconfig = quant_config
    self.qkv_proj = QKVParallelLinear(
        hidden_size=hidden_size,
        head_size=self.head_dim,
        total_num_heads=self.total_num_heads,
        total_num_kv_heads=self.total_num_kv_heads,
        bias=bias,
        quant_config=qconfig,
        prefix=f"{prefix}.qkv_proj",
    )
    self.o_proj = RowParallelLinear(
        input_size=self.total_num_heads * self.head_dim,
        output_size=hidden_size,
        bias=bias,
        quant_config=quant_config,
        prefix=f"{prefix}.o_proj",
    )
    self.rotary_emb = get_rope(
        self.head_dim,
        rotary_dim=self.head_dim,
        max_position=max_position_embeddings,
        base=rope_theta,
        rope_scaling=rope_scaling,
    )
    # Paged attention; extra kwargs only when dual-chunk attention is used.
    self.attn = Attention(
        self.num_heads,
        self.head_dim,
        self.scaling,
        num_kv_heads=self.num_kv_heads,
        cache_config=cache_config,
        quant_config=quant_config,
        attn_type=attn_type,
        prefix=f"{prefix}.attn",
        **{
            "layer_idx": extract_layer_index(prefix),
            "dual_chunk_attention_config": dual_chunk_attention_config,
        } if dual_chunk_attention_config else {})
# Install the Biren (BR) replacements onto vLLM's LLaMA classes at import
# time. (LlamaAttention.__init__ is patched separately via @patch_to above.)
vllm.model_executor.models.llama.LlamaDecoderLayer.__init__ = LlamaDecoderLayer__init__
LlamaForCausalLM.load_weights = load_weights
LlamaModel.forward = llamamodel_forward
LlamaAttention.forward = LlamaAttention_forward