################################################################################ # Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ################################################################################ # SPDX-License-Identifier: Apache-2.0 # Adapted from # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py # Copyright 2023 The vLLM team. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX # and OPT implementations in this library. It has been modified from its # original forms to accommodate minor architectural differences compared # to GPT-NeoX and OPT used by the Meta AI team that trained the model. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, Optional, Tuple, Union

import torch
import torch_br
from fastcore.basics import patch_to
from transformers import LlamaConfig

import vllm.model_executor.models.llama
from vllm.attention import Attention, AttentionType
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.llama import (LlamaAttention,
                                              LlamaDecoderLayer,
                                              LlamaForCausalLM, LlamaModel)
from vllm.model_executor.models.utils import (extract_layer_index,
                                              is_pp_missing_parameter)
from vllm.sequence import IntermediateTensors
from vllm_br import envs

from ..layers.quantization.compressed_tensors.utils import (
    get_compressed_tensors_cache_scale)
from .supa_module import AttentionSplit, MergedGateUpMLPSiluL2


def LlamaDecoderLayer__init__(self,
                              vllm_config: VllmConfig,
                              prefix: str = "",
                              config: Optional[LlamaConfig] = None) -> None:
    """Replacement ``LlamaDecoderLayer.__init__`` for the SUPA backend.

    Builds the attention block (either ``AttentionSplit`` or the patched
    ``LlamaAttention``), a ``MergedGateUpMLPSiluL2`` MLP and the two RMSNorm
    layers.

    Args:
        vllm_config: Engine configuration; supplies the cache and
            quantization configs and, when ``config`` is not given, the
            HuggingFace model config.
        prefix: Parameter-name prefix of this layer.
        config: Optional HuggingFace config overriding the one stored in
            ``vllm_config``.
    """
    super(LlamaDecoderLayer, self).__init__()
    config = config or vllm_config.model_config.hf_config
    cache_config = vllm_config.cache_config
    quant_config = vllm_config.quant_config
    self.hidden_size = config.hidden_size
    rope_theta = getattr(config, "rope_theta", 10000)
    rope_scaling = getattr(config, "rope_scaling", None)
    if rope_scaling is not None and getattr(
            config, "original_max_position_embeddings", None):
        rope_scaling["original_max_position_embeddings"] = (
            config.original_max_position_embeddings)
    max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
    # Support abacusai/Smaug-72B-v0.1 with attention_bias
    # Support internlm/internlm-7b with bias
    attention_bias = getattr(config, "attention_bias", False) or getattr(
        config, "bias", False)
    # Resolve the KV-head count once, with the same fallback the attention
    # constructors below use, so configs lacking `num_key_value_heads` do not
    # raise AttributeError in the heuristic.
    num_kv_heads = getattr(config, "num_key_value_heads",
                           config.num_attention_heads)
    tp_size = get_tensor_model_parallel_world_size()
    spc_num = torch_br.supa.get_device_properties("supa").max_compute_units

    # Determine whether to use split or merged qkv weights.
    min_w_gran = 32
    is_166 = envs.VLLM_BR_DEVICE_SPC_NUM > 16
    kv_width = num_kv_heads * (self.hidden_size // config.num_attention_heads)
    # NOTE: br166 currently does not support the s(2)b split, so br166 can
    # only use AttentionSplit.
    if is_166 or kv_width >= tp_size * spc_num * min_w_gran:
        self.self_attn = AttentionSplit(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=num_kv_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
    else:
        self.self_attn = LlamaAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=num_kv_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
    self.mlp = MergedGateUpMLPSiluL2(
        hidden_size=self.hidden_size,
        intermediate_size=config.intermediate_size,
        hidden_act=config.hidden_act,
        quant_config=quant_config,
        bias=getattr(config, "mlp_bias", False),
        prefix=f"{prefix}.mlp",
    )
    self.input_layernorm = RMSNorm(config.hidden_size,
                                   eps=config.rms_norm_eps)
    self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                            eps=config.rms_norm_eps)


def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Load checkpoint tensors into this model's parameters.

    Installed as ``LlamaForCausalLM.load_weights``. Checkpoint names for the
    q/k/v and gate/up projections are remapped onto the fused parameters;
    the q/k/v remapping is skipped when the model has no ``qkv_proj``
    parameter.

    Args:
        weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

    Returns:
        The set of parameter names that were loaded.
    """
    loaded_params = set()
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    params_dict = dict(self.named_parameters())

    # Without a merged qkv parameter, q/k/v weights are loaded under their
    # own names, so drop the three qkv entries from the mapping.
    qkv_merge = any("qkv_proj" in key for key in params_dict)
    if not qkv_merge:
        stacked_params_mapping = stacked_params_mapping[3:]

    tie_word_embeddings = getattr(getattr(self, "config", None),
                                  "tie_word_embeddings", False)

    for name, loaded_weight in weights:
        if "rotary_emb.inv_freq" in name:
            continue
        if ("rotary_emb.cos_cached" in name
                or "rotary_emb.sin_cached" in name):
            # Models trained using ColossalAI may include these tensors in
            # the checkpoint. Skip them.
            continue
        # With tied word embeddings the checkpoint may still carry a
        # redundant lm_head.weight (e.g. after quantization or fine-tuning);
        # skip it instead of failing on the params_dict lookup below.
        if tie_word_embeddings and "lm_head.weight" in name:
            continue
        if scale_name := get_compressed_tensors_cache_scale(name):
            # Loading kv cache scales for compressed-tensors quantization
            param = params_dict[scale_name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            loaded_weight = loaded_weight[0]
            weight_loader(param, loaded_weight)
            loaded_params.add(scale_name)
            continue
        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            # weight layout infer
            # NOTE(review): presumably rebinding to a fresh tensor lets the
            # backend settle the device layout - confirm with torch_br.
            param.data = param.data + 0
            loaded_params.add(name)
            break
        else:
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # Remapping the name of FP8 kv-scale.
            name = maybe_remap_kv_scale_name(name, params_dict)
            if name is None:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            # weight layout infer
            param.data = param.data + 0
            # Norm weights are kept in float32.
            if "norm.weight" in name:
                param.data = param.data.to(torch.float32)
            loaded_params.add(name)
    return loaded_params


def llamamodel_forward(
    self,
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor,
                                                    list[torch.Tensor]]]:
    """Replacement ``LlamaModel.forward`` for the SUPA backend.

    Hidden states (and the residual received from a previous pipeline rank)
    get a leading dimension via ``unsqueeze(0)`` before the decoder layers
    run; that dimension is removed again on the regular return paths.

    Args:
        input_ids: Token ids, used on the first pipeline rank when
            ``inputs_embeds`` is not given.
        positions: Position tensor forwarded to every decoder layer.
        intermediate_tensors: Output of the previous pipeline rank; required
            on every rank but the first.
        inputs_embeds: Optional precomputed input embeddings.

    Returns:
        ``IntermediateTensors`` on non-last pipeline ranks; otherwise the
        normalized hidden states, or a ``(hidden_states, aux_hidden_states)``
        tuple when auxiliary hidden states were collected.
    """
    if get_pp_group().is_first_rank:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None
        hidden_states = hidden_states.unsqueeze(0)
    else:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]
        hidden_states = hidden_states.unsqueeze(0)
        residual = residual.unsqueeze(0)

    aux_hidden_states = []
    for idx, layer in enumerate(self.layers[self.start_layer:self.end_layer]):
        if idx in self.aux_hidden_state_layers:
            # On the first pipeline rank there is no residual before the
            # first layer runs; adding None would raise a TypeError.
            aux_hidden_states.append(hidden_states if residual is None else
                                     hidden_states + residual)
        hidden_states, residual = layer(positions, hidden_states, residual)

    if not get_pp_group().is_last_rank:
        return IntermediateTensors({
            "hidden_states":
            hidden_states.squeeze(0)
            if hidden_states is not None else hidden_states,
            "residual":
            residual.squeeze(0) if residual is not None else residual
        })

    hidden_states, _ = self.norm(hidden_states, residual)

    if len(aux_hidden_states) > 0:
        # NOTE(review): this path returns tensors that still carry the
        # leading dimension added above, unlike the plain return below -
        # confirm that callers consuming aux hidden states expect this.
        return hidden_states, aux_hidden_states
    return hidden_states.squeeze(0)


def LlamaAttention_forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """Replacement ``LlamaAttention.forward`` for the SUPA backend.

    Projects to fused qkv, splits into q/k/v, applies rotary embedding to q
    and k, runs attention and the output projection.

    Args:
        positions: Positions passed to the rotary embedding.
        hidden_states: Input activations.

    Returns:
        Output of ``o_proj`` applied to the attention result.
    """
    qkv, _ = self.qkv_proj(hidden_states)
    split_sizes = [self.q_size, self.kv_size, self.kv_size]
    if envs.VLLM_BR_DEVICE_SPC_NUM > 16:
        # Devices with more than 16 SPCs use the torch_br split helper
        # instead of a plain tensor split along the last dimension.
        q, k, v = torch_br.split_w_sbp_infer(qkv, split_sizes)
    else:
        q, k, v = qkv.split(split_sizes, dim=-1)
    q, k = self.rotary_emb(positions, q, k)
    attn_output = self.attn(q, k, v)
    output, _ = self.o_proj(attn_output)
    return output


@patch_to(LlamaAttention)
def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
        attn_type: str = AttentionType.DECODER,
        prefix: str = "",
        dual_chunk_attention_config: Optional[dict[str, Any]] = None) -> None:
    """Replacement ``LlamaAttention.__init__`` for the SUPA backend.

    Creates the fused qkv projection, the output projection, the rotary
    embedding and the attention layer. The qkv projection only receives the
    quantization config when ``quant_config.qkv_quantized`` is truthy.

    Args:
        config: HuggingFace config; only an optional ``head_dim`` is read.
        hidden_size: Model hidden size.
        num_heads: Total number of attention heads (before TP sharding).
        num_kv_heads: Total number of KV heads (before TP sharding).
        rope_theta: Rotary embedding base.
        rope_scaling: Optional rotary scaling configuration.
        max_position_embeddings: Maximum position for the rotary embedding.
        quant_config: Optional quantization configuration.
        bias: Whether the qkv and output projections use a bias.
        cache_config: Optional KV-cache configuration.
        attn_type: Attention type passed to ``Attention``.
        prefix: Parameter-name prefix of this module.
        dual_chunk_attention_config: Optional dual-chunk attention settings;
            when truthy, it and the layer index are forwarded to
            ``Attention``.
    """
    super(LlamaAttention, self).__init__()
    self.hidden_size = hidden_size
    tp_size = get_tensor_model_parallel_world_size()
    self.total_num_heads = num_heads
    assert self.total_num_heads % tp_size == 0
    self.num_heads = self.total_num_heads // tp_size
    self.total_num_kv_heads = num_kv_heads
    if self.total_num_kv_heads >= tp_size:
        # Number of KV heads is greater than TP size, so we partition
        # the KV heads across multiple tensor parallel GPUs.
        assert self.total_num_kv_heads % tp_size == 0
    else:
        # Number of KV heads is less than TP size, so we replicate
        # the KV heads across multiple tensor parallel GPUs.
        assert tp_size % self.total_num_kv_heads == 0
    self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
    # MistralConfig has an optional head_dim introduced by Mistral-Nemo.
    # Some configs define the attribute but leave it as None, so fall back
    # to the derived value in that case as well.
    head_dim = getattr(config, "head_dim", None)
    if head_dim is None:
        head_dim = self.hidden_size // self.total_num_heads
    self.head_dim = head_dim
    self.q_size = self.num_heads * self.head_dim
    self.kv_size = self.num_kv_heads * self.head_dim
    self.scaling = self.head_dim**-0.5
    self.rope_theta = rope_theta
    self.max_position_embeddings = max_position_embeddings

    # Only quantize the qkv projection when the config asks for it.
    qconfig = None
    if quant_config is not None and quant_config.qkv_quantized:
        qconfig = quant_config
    self.qkv_proj = QKVParallelLinear(
        hidden_size=hidden_size,
        head_size=self.head_dim,
        total_num_heads=self.total_num_heads,
        total_num_kv_heads=self.total_num_kv_heads,
        bias=bias,
        quant_config=qconfig,
        prefix=f"{prefix}.qkv_proj",
    )

    self.o_proj = RowParallelLinear(
        input_size=self.total_num_heads * self.head_dim,
        output_size=hidden_size,
        bias=bias,
        quant_config=quant_config,
        prefix=f"{prefix}.o_proj",
    )

    self.rotary_emb = get_rope(
        self.head_dim,
        rotary_dim=self.head_dim,
        max_position=max_position_embeddings,
        base=rope_theta,
        rope_scaling=rope_scaling,
    )

    self.attn = Attention(
        self.num_heads,
        self.head_dim,
        self.scaling,
        num_kv_heads=self.num_kv_heads,
        cache_config=cache_config,
        quant_config=quant_config,
        attn_type=attn_type,
        prefix=f"{prefix}.attn",
        **{
            "layer_idx": extract_layer_index(prefix),
            "dual_chunk_attention_config": dual_chunk_attention_config,
        } if dual_chunk_attention_config else {})


vllm.model_executor.models.llama.LlamaDecoderLayer.__init__ = LlamaDecoderLayer__init__
LlamaForCausalLM.load_weights = load_weights
LlamaModel.forward = llamamodel_forward
LlamaAttention.forward = LlamaAttention_forward