################################################################################ # Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ################################################################################
"""Biren (SUPA) hardware patches for vLLM's Qwen2 model.

This module monkey-patches ``Qwen2DecoderLayer.__init__``,
``Qwen2Model.load_weights``, ``Qwen2Model.forward`` and
``Qwen2Attention.forward`` (see the assignments at the bottom of the file)
to:

* choose between ``AttentionSplit`` (unmerged q/k/v) and the stock
  ``Qwen2Attention`` depending on the device SPC count and head geometry;
* use ``MergedGateUpMLPSiluL2`` instead of the stock Qwen2 MLP;
* relocate norm / embedding / lm_head / rope-cache weights into SUPA
  "ut" tensors with device-specific layouts after loading;
* keep hidden states 3-D through the decoder layers, as the SUPA
  kernels expect.
"""
import gc
from collections.abc import Iterable
from typing import Optional, Union

import torch
import torch_br
from transformers import Qwen2Config

import vllm.model_executor.models.qwen2
from vllm.attention import AttentionType
from vllm.config import CacheConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.qwen2 import (Qwen2Attention,
                                              Qwen2DecoderLayer, Qwen2Model)
from vllm.model_executor.models.utils import is_pp_missing_parameter
from vllm.sequence import IntermediateTensors

#import vllm.envs as envs
from vllm_br import envs

from .supa_module import AttentionSplit, MergedGateUpMLPSiluL2


def Qwen2DecoderLayer__init__(
    self,
    config: Qwen2Config,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None:
    """Patched ``Qwen2DecoderLayer.__init__``.

    Identical to upstream except that the attention module is chosen
    between ``AttentionSplit`` and ``Qwen2Attention`` based on the device
    SPC count / KV-head width, and the MLP is ``MergedGateUpMLPSiluL2``.

    Args:
        config: HF ``Qwen2Config`` for the model.
        cache_config: vLLM KV-cache configuration, forwarded to attention.
        quant_config: quantization configuration, forwarded to submodules.
        prefix: parameter-name prefix for this layer (e.g. ``model.layers.0``).
    """
    super(Qwen2DecoderLayer, self).__init__()
    self.hidden_size = config.hidden_size
    # Requires transformers > 4.32.0
    rope_theta = getattr(config, "rope_theta", 1000000)
    rope_scaling = getattr(config, "rope_scaling", None)
    dual_chunk_attention_config = getattr(config,
                                          "dual_chunk_attention_config",
                                          None)

    # By default, Qwen2 uses causal attention as it is a decoder-only model.
    # You can override the HF config with `is_causal=False` to enable
    # bidirectional attention, which is used in some embedding models
    # (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct)
    if getattr(config, "is_causal", True):
        attn_type = AttentionType.DECODER
    else:
        attn_type = AttentionType.ENCODER_ONLY

    # NOTE(review): upstream vLLM defaults these getattr() calls to False;
    # this patch defaults them to True — presumably deliberate for the SUPA
    # kernels, but worth confirming.
    attention_bias = getattr(config, "attention_bias", True) or getattr(
        config, "bias", True)

    tp_size = get_tensor_model_parallel_world_size()
    spc_num = torch_br.supa.get_device_properties("supa").max_compute_units
    # determine whether use qkv merge weights
    min_w_gran = 32
    is_166 = envs.VLLM_BR_DEVICE_SPC_NUM > 16
    # NOTE: current br166 don't support s(2)b split, so br166 can only use AttentionSplit
    if is_166 or (config.num_key_value_heads *
                  (self.hidden_size // config.num_attention_heads) >=
                  tp_size * spc_num * min_w_gran):
        # Separate q/k/v projections (no merged qkv_proj weight).
        self.self_attn = AttentionSplit(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            cache_config=cache_config,
            quant_config=quant_config,
            rope_scaling=rope_scaling,
            prefix=f"{prefix}.self_attn",
            bias=attention_bias,
        )
        logger.debug('[Patch] Use AttentionSplit instead of Qwen2Attention')
    else:
        self.self_attn = Qwen2Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            cache_config=cache_config,
            quant_config=quant_config,
            rope_scaling=rope_scaling,
            prefix=f"{prefix}.self_attn",
            attn_type=attn_type,
            dual_chunk_attention_config=dual_chunk_attention_config,
        )
    self.mlp = MergedGateUpMLPSiluL2(
        hidden_size=self.hidden_size,
        intermediate_size=config.intermediate_size,
        hidden_act=config.hidden_act,
        quant_config=quant_config,
        prefix=f"{prefix}.mlp",
    )
    self.input_layernorm = RMSNorm(config.hidden_size,
                                   eps=config.rms_norm_eps)
    self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                            eps=config.rms_norm_eps)


def load_weights(self,
                 weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Patched ``Qwen2Model.load_weights`` with SUPA weight relocation.

    Follows the upstream stacked-parameter loading logic, then relocates
    norm weights, ``embed_tokens.weight``, ``lm_head.weight`` and the rope
    sin/cos caches into SUPA "ut" tensors (platform 1, i.e. SPC count > 16)
    or materializes fresh copies via ``+ 0`` (platform 0).

    Args:
        weights: iterable of ``(checkpoint_name, tensor)`` pairs.

    Returns:
        The set of parameter names that were actually loaded.
    """
    spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    # platform 1 == br166-class device (SPC count > 16), platform 0 otherwise.
    self.platform = 0
    if spc_num > 16:
        self.platform = 1
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    logger.info('[Patch] Qwen2 MLP do not merge up/gate weight')
    params_dict = dict(self.named_parameters(remove_duplicate=False))
    loaded_params: set[str] = set()

    # FIX: this decision is loop-invariant — params_dict is fixed once built —
    # so compute it once instead of rescanning params_dict (and re-trimming
    # stacked_params_mapping) for every checkpoint weight.
    qkv_merge = any("qkv_proj" in key for key in params_dict)
    if not qkv_merge:
        # AttentionSplit keeps q/k/v separate: drop the three qkv stacking
        # entries, keeping only the gate_up mapping.
        stacked_params_mapping = stacked_params_mapping[3:]

    for name, loaded_weight in weights:
        if "rotary_emb.inv_freq" in name:
            continue
        if (self.quant_config is not None and
                (scale_name := self.quant_config.get_cache_scale(name))):
            # Loading kv cache quantization scales
            param = params_dict[scale_name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                             loaded_weight[0])
            weight_loader(param, loaded_weight)
            loaded_params.add(scale_name)
            continue
        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # Remapping the name of FP8 kv-scale.
            name = maybe_remap_kv_scale_name(name, params_dict)
            if name is None:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        if self.platform == 0:
            # `+ 0` allocates a fresh tensor — presumably to force
            # materialization in a device-friendly layout; confirm.
            param.data = param.data + 0
        if name.find("norm.weight") != -1:
            if self.platform == 1:
                # Relocate norm weights into a SUPA ut tensor (fp32,
                # linear_bias layout, broadcast sbp).
                w_cpu = param.data.to(torch.float32).cpu()
                w_supa = torch_br._empty_ut_only(w_cpu.shape,
                                                 dtype=w_cpu.dtype,
                                                 is_numa=False,
                                                 device=param.data.device,
                                                 tensor_type="linear_bias",
                                                 axis=0,
                                                 sbp="BB")
                w_supa.copy_(w_cpu)
                param.data = w_supa
            else:
                param.data = param.data.to(torch.float32)
        if name.find("embed_tokens.weight") != -1 and self.platform == 1:
            # Embedding table: column-major ut tensor, broadcast sbp.
            w_shape = param.data.shape
            w_supa = torch_br._empty_ut_only(size=(w_shape[0], w_shape[1]),
                                             dtype=param.data.dtype,
                                             is_numa=False,
                                             device=param.data.device,
                                             tensor_type="colmajor",
                                             axis=0,
                                             sbp="BB")
            w_supa.copy_(param.data.cpu())
            param.data = w_supa
        if name.find("lm_head.weight") != -1 and self.platform == 1:
            # LM head: column-major ut tensor, split/broadcast ("SB") sbp.
            w_shape = param.data.shape
            w_supa = torch_br._empty_ut_only(size=(w_shape[0], w_shape[1]),
                                             dtype=param.data.dtype,
                                             is_numa=False,
                                             device=param.data.device,
                                             tensor_type="colmajor",
                                             axis=0,
                                             sbp="SB")
            w_supa.copy_(param.data.cpu())
            param.data = w_supa
        loaded_params.add(name)

    # inference rope sin_cos layout
    for _, module in self.named_modules():
        rotary_emb = getattr(module, "rotary_emb", None)
        if rotary_emb is not None:
            if self.platform == 1:
                if isinstance(rotary_emb, MRotaryEmbedding):
                    # MRoPE keeps a fused cos/sin cache.
                    w_shape = rotary_emb.cos_sin_cache.shape
                    cos_sin_supa = torch_br._empty_ut_only(
                        size=(w_shape[0], w_shape[1]),
                        dtype=rotary_emb.cos_sin_cache.dtype,
                        is_numa=False,
                        device=rotary_emb.cos_sin_cache.device,
                        tensor_type="colmajor",
                        axis=0,
                        sbp="BB")
                    cos_sin_supa.copy_(rotary_emb.cos_sin_cache.cpu())
                    rotary_emb.cos_sin_cache = cos_sin_supa
                else:
                    # Standard rope keeps separate sin/cos caches.
                    w_shape = rotary_emb.sin_cache.shape
                    sin_supa = torch_br._empty_ut_only(
                        size=(w_shape[0], w_shape[1]),
                        dtype=rotary_emb.sin_cache.dtype,
                        is_numa=False,
                        device=rotary_emb.sin_cache.device,
                        tensor_type="colmajor",
                        axis=0,
                        sbp="BB")
                    sin_supa.copy_(rotary_emb.sin_cache.cpu())
                    rotary_emb.sin_cache = sin_supa
                    cos_supa = torch_br._empty_ut_only(
                        size=(w_shape[0], w_shape[1]),
                        dtype=rotary_emb.cos_cache.dtype,
                        is_numa=False,
                        device=rotary_emb.cos_cache.device,
                        tensor_type="colmajor",
                        axis=0,
                        sbp="BB")
                    cos_supa.copy_(rotary_emb.cos_cache.cpu())
                    rotary_emb.cos_cache = cos_supa
            else:
                # Platform 0: just materialize fresh copies of the caches.
                if isinstance(rotary_emb, MRotaryEmbedding):
                    rotary_emb.cos_sin_cache = rotary_emb.cos_sin_cache + 0
                else:
                    rotary_emb.sin_cache = rotary_emb.sin_cache + 0
                    rotary_emb.cos_cache = rotary_emb.cos_cache + 0

    # Make sure all relocations have landed, then release the originals.
    torch.supa.synchronize()
    gc.collect()
    torch.supa.empty_cache()
    return loaded_params


def model_forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Patched ``Qwen2Model.forward``.

    Same pipeline-parallel logic as upstream, but hidden states are kept
    3-D (batch, seq, dim) through the decoder layers because the SUPA
    kernels expect a 3-D input, and squeezed back to 2-D before returning.

    Args:
        input_ids: token ids (used on the first PP rank when
            ``inputs_embeds`` is not given).
        positions: position ids forwarded to every decoder layer.
        intermediate_tensors: hidden state / residual handoff from the
            previous PP rank (required on non-first ranks).
        inputs_embeds: precomputed embeddings, bypassing the lookup.

    Returns:
        Final hidden states on the last PP rank, otherwise an
        ``IntermediateTensors`` for the next rank.
    """
    if get_pp_group().is_first_rank:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None
    else:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]
    # NOTE: supa wants 3d shape for llm
    if len(hidden_states.shape) == 2:
        hidden_states = hidden_states.unsqueeze(0)
    for layer in self.layers[self.start_layer:self.end_layer]:
        hidden_states, residual = layer(
            positions,
            hidden_states,
            residual,
        )
    if not get_pp_group().is_last_rank:
        # Hand off 2-D tensors to the next rank (drop the batch dim added above).
        return IntermediateTensors({
            "hidden_states":
            hidden_states.squeeze(0) if hidden_states is not None else None,
            "residual":
            residual.squeeze(0) if residual is not None else None
        })
    hidden_states, _ = self.norm(hidden_states, residual)
    # NOTE: convert back to 2D
    hidden_states = hidden_states.squeeze()
    if hidden_states.dim() == 1:
        # A single token would squeeze down to 1-D; restore the seq dim.
        hidden_states = hidden_states.unsqueeze(0)
    return hidden_states


def Qwen2Attention_forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """Patched ``Qwen2Attention.forward``.

    Identical to upstream except that on SPC-count > 16 devices the fused
    qkv output is split with ``torch_br.split_w_sbp_infer`` instead of
    ``Tensor.split``.
    """
    qkv, _ = self.qkv_proj(hidden_states)
    if envs.VLLM_BR_DEVICE_SPC_NUM > 16:
        q, k, v = torch_br.split_w_sbp_infer(
            qkv, [self.q_size, self.kv_size, self.kv_size])
    else:
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
                            dim=-1)
    q, k = self.rotary_emb(positions, q, k)
    attn_output = self.attn(q, k, v)
    output, _ = self.o_proj(attn_output)
    return output


# Install the patches at import time.
vllm.model_executor.models.qwen2.Qwen2DecoderLayer.__init__ = Qwen2DecoderLayer__init__
logger.debug('[Patch] patch Qwen2 MLP with LlaMA_MLP_SiLU_3L')
Qwen2Model.load_weights = load_weights
Qwen2Model.forward = model_forward
Qwen2Attention.forward = Qwen2Attention_forward