Files
2026-03-10 13:31:25 +08:00

350 lines
14 KiB
Python

################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import gc
from collections.abc import Iterable
from typing import Optional, Union
import torch
import torch_br
from transformers import Qwen2Config
import vllm.model_executor.models.qwen2
from vllm.attention import AttentionType
from vllm.config import CacheConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.qwen2 import (Qwen2Attention,
Qwen2DecoderLayer, Qwen2Model)
from vllm.model_executor.models.utils import is_pp_missing_parameter
from vllm.sequence import IntermediateTensors
#import vllm.envs as envs
from vllm_br import envs
from .supa_module import AttentionSplit, MergedGateUpMLPSiluL2
def Qwen2DecoderLayer__init__(
    self,
    config: Qwen2Config,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None:
    """Replacement for ``Qwen2DecoderLayer.__init__`` on Biren SUPA devices.

    Builds, in this order, ``self.self_attn`` (either ``AttentionSplit`` or
    the stock ``Qwen2Attention``), ``self.mlp`` (``MergedGateUpMLPSiluL2``),
    ``self.input_layernorm`` and ``self.post_attention_layernorm``.

    ``AttentionSplit`` is chosen when running on br166 (more than 16 SPCs
    according to ``envs.VLLM_BR_DEVICE_SPC_NUM``) or when the K/V projection
    width ``num_key_value_heads * (hidden_size // num_attention_heads)`` is at
    least ``tp_size * max_compute_units * 32``.

    Args:
        config: HF Qwen2 configuration of the model.
        cache_config: vLLM KV-cache configuration, forwarded to attention.
        quant_config: Optional quantization config, forwarded to submodules.
        prefix: Parameter-name prefix of this decoder layer.
    """
    super(Qwen2DecoderLayer, self).__init__()
    self.hidden_size = config.hidden_size

    # Optional config fields; rope settings require transformers > 4.32.0.
    rope_theta = getattr(config, "rope_theta", 1000000)
    rope_scaling = getattr(config, "rope_scaling", None)
    dual_chunk_attention_config = getattr(config,
                                          "dual_chunk_attention_config", None)

    # Qwen2 is decoder-only, hence causal by default. Setting
    # `is_causal=False` in the HF config selects bidirectional attention,
    # which some embedding models use (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct).
    attn_type = (AttentionType.DECODER
                 if getattr(config, "is_causal", True) else
                 AttentionType.ENCODER_ONLY)

    attention_bias = (getattr(config, "attention_bias", True)
                      or getattr(config, "bias", True))

    tp_size = get_tensor_model_parallel_world_size()
    spc_num = torch_br.supa.get_device_properties("supa").max_compute_units
    # Minimum weight-column granularity per compute unit.
    min_w_gran = 32

    # NOTE: current br166 doesn't support s(2)b split, so br166 can only use
    # AttentionSplit.
    if envs.VLLM_BR_DEVICE_SPC_NUM > 16:
        use_attention_split = True
    else:
        head_dim = self.hidden_size // config.num_attention_heads
        kv_width = config.num_key_value_heads * head_dim
        use_attention_split = kv_width >= tp_size * spc_num * min_w_gran

    # Arguments common to both attention implementations.
    shared_attn_kwargs = dict(
        hidden_size=self.hidden_size,
        num_heads=config.num_attention_heads,
        max_position=config.max_position_embeddings,
        num_kv_heads=config.num_key_value_heads,
        rope_theta=rope_theta,
        cache_config=cache_config,
        quant_config=quant_config,
        rope_scaling=rope_scaling,
        prefix=f"{prefix}.self_attn",
    )
    if use_attention_split:
        # NOTE(review): attn_type and dual_chunk_attention_config are not
        # forwarded to AttentionSplit; confirm non-causal configs are not
        # expected on this path.
        self.self_attn = AttentionSplit(**shared_attn_kwargs,
                                        bias=attention_bias)
        logger.debug('[Patch] Use AttentionSplit instead of Qwen2Attention')
    else:
        self.self_attn = Qwen2Attention(
            **shared_attn_kwargs,
            attn_type=attn_type,
            dual_chunk_attention_config=dual_chunk_attention_config,
        )

    self.mlp = MergedGateUpMLPSiluL2(
        hidden_size=self.hidden_size,
        intermediate_size=config.intermediate_size,
        hidden_act=config.hidden_act,
        quant_config=quant_config,
        prefix=f"{prefix}.mlp",
    )
    norm_eps = config.rms_norm_eps
    self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
    self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
def _br_copy_to_supa(src: torch.Tensor, size, device, tensor_type: str,
                     sbp: str) -> torch.Tensor:
    """Allocate a SUPA tensor with an explicit layout and copy ``src`` into it.

    Args:
        src: Tensor holding the values (callers pass a host copy); its dtype
            is used for the new tensor.
        size: Shape forwarded to ``torch_br._empty_ut_only``.
        device: Device of the new tensor.
        tensor_type: Layout tag forwarded to ``_empty_ut_only``
            (this file uses "linear_bias" and "colmajor").
        sbp: SBP tag forwarded to ``_empty_ut_only`` (this file uses "BB"
            and "SB").

    Returns:
        The newly allocated tensor filled with ``src``.
    """
    dst = torch_br._empty_ut_only(size=size,
                                  dtype=src.dtype,
                                  is_numa=False,
                                  device=device,
                                  tensor_type=tensor_type,
                                  axis=0,
                                  sbp=sbp)
    dst.copy_(src)
    return dst


def _br_relayout_loaded_param(param: torch.nn.Parameter, name: str,
                              platform: int) -> None:
    """Apply the device-specific post-load fixups to one loaded parameter.

    Args:
        param: Parameter whose ``.data`` was just filled by a weight loader.
        name: Final (possibly remapped) parameter name.
        platform: 1 on br166 (more than 16 SPCs), 0 otherwise.
    """
    if platform == 0:
        # Rebind the weight to a freshly computed tensor (kept from the
        # original implementation).
        param.data = param.data + 0
    if "norm.weight" in name:
        # Norm weights are kept in float32.
        if platform == 1:
            w_cpu = param.data.to(torch.float32).cpu()
            param.data = _br_copy_to_supa(w_cpu, w_cpu.shape,
                                          param.data.device, "linear_bias",
                                          "BB")
        else:
            param.data = param.data.to(torch.float32)
    if platform == 1:
        if "embed_tokens.weight" in name:
            shape = param.data.shape
            param.data = _br_copy_to_supa(param.data.cpu(),
                                          (shape[0], shape[1]),
                                          param.data.device, "colmajor", "BB")
        if "lm_head.weight" in name:
            shape = param.data.shape
            param.data = _br_copy_to_supa(param.data.cpu(),
                                          (shape[0], shape[1]),
                                          param.data.device, "colmajor", "SB")


def _br_relayout_rotary_caches(model: torch.nn.Module, platform: int) -> None:
    """Re-layout the sin/cos caches of every ``rotary_emb`` in ``model``.

    ``MRotaryEmbedding`` instances hold a single ``cos_sin_cache``; any other
    rotary module is expected to expose ``sin_cache`` and ``cos_cache``.
    On br166 (``platform == 1``) each cache is copied into a "colmajor"/"BB"
    SUPA tensor; otherwise it is rebound to a fresh ``cache + 0`` tensor.

    Args:
        model: Module tree to scan for ``rotary_emb`` attributes.
        platform: 1 on br166 (more than 16 SPCs), 0 otherwise.
    """
    # Several attention modules may reference the same rotary_emb object
    # (e.g. when it comes from a cached factory); convert each object once
    # instead of repeating the host round-trip per referencing module.
    handled: set[int] = set()
    for module in model.modules():
        rotary_emb = getattr(module, "rotary_emb", None)
        if rotary_emb is None or id(rotary_emb) in handled:
            continue
        handled.add(id(rotary_emb))
        if isinstance(rotary_emb, MRotaryEmbedding):
            cache_attrs = ("cos_sin_cache", )
        else:
            cache_attrs = ("sin_cache", "cos_cache")
        for attr in cache_attrs:
            cache = getattr(rotary_emb, attr)
            if platform == 1:
                new_cache = _br_copy_to_supa(cache.cpu(),
                                             (cache.shape[0], cache.shape[1]),
                                             cache.device, "colmajor", "BB")
            else:
                new_cache = cache + 0
            setattr(rotary_emb, attr, new_cache)


def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Replacement for ``Qwen2Model.load_weights`` on Biren SUPA devices.

    Loads checkpoint tensors the way upstream vLLM does (stacked q/k/v and
    gate/up shards, kv-cache quantization scales, FP8 kv-scale remapping,
    pipeline-parallel skipping). It then applies SUPA layout fixups to norm,
    embedding and lm_head weights and to the rotary sin/cos caches. Finally it
    synchronizes the device and releases cached memory.

    Side effect: sets ``self.platform`` to 1 when
    ``envs.VLLM_BR_DEVICE_SPC_NUM > 16`` (br166), else 0.

    Args:
        weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

    Returns:
        The set of parameter names that were loaded.
    """
    spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    self.platform = 1 if spc_num > 16 else 0
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    logger.info('[Patch] Qwen2 MLP do not merge up/gate weight')
    params_dict = dict(self.named_parameters(remove_duplicate=False))
    loaded_params: set[str] = set()

    # Parameter names do not change while loading, so decide once (instead of
    # per checkpoint tensor) whether a fused qkv_proj exists. Without it,
    # q/k/v checkpoint tensors are loaded under their own names.
    if not any("qkv_proj" in key for key in params_dict):
        stacked_params_mapping = stacked_params_mapping[3:]

    for name, loaded_weight in weights:
        if "rotary_emb.inv_freq" in name:
            continue
        if (self.quant_config is not None
                and (scale_name := self.quant_config.get_cache_scale(name))):
            # Loading kv cache quantization scales
            param = params_dict[scale_name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            loaded_weight = (loaded_weight
                             if loaded_weight.dim() == 0 else loaded_weight[0])
            weight_loader(param, loaded_weight)
            loaded_params.add(scale_name)
            continue
        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # Remapping the name of FP8 kv-scale.
            name = maybe_remap_kv_scale_name(name, params_dict)
            if name is None:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
        # NOTE(review): fixups run after both the stacked and the plain load
        # path; confirm this matches the intended placement.
        _br_relayout_loaded_param(param, name, self.platform)
        loaded_params.add(name)

    # inference rope sin_cos layout
    _br_relayout_rotary_caches(self, self.platform)

    torch.supa.synchronize()
    gc.collect()
    torch.supa.empty_cache()
    return loaded_params
def model_forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Replacement for ``Qwen2Model.forward`` on Biren SUPA devices.

    SUPA kernels expect 3D activations, so 2D inputs get a leading batch
    dimension for the decoder layers. Results are returned in 2D.

    Args:
        input_ids: Token ids, used on the first PP rank when
            ``inputs_embeds`` is not given.
        positions: Position ids forwarded to every decoder layer.
        intermediate_tensors: ``hidden_states``/``residual`` from the
            previous pipeline stage (required on non-first PP ranks).
        inputs_embeds: Optional precomputed input embeddings.

    Returns:
        On non-last PP ranks, ``IntermediateTensors`` with the leading
        dimension squeezed away. On the last rank, the normed hidden states
        in 2D.
    """
    if get_pp_group().is_first_rank:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None
    else:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]

    # NOTE: supa wants 3d shape for llm. The previous pipeline stage sends
    # both tensors squeezed to 2D, so lift the residual too; otherwise the
    # first layer here would get a 3D hidden_states with a 2D residual.
    if hidden_states.dim() == 2:
        hidden_states = hidden_states.unsqueeze(0)
    if residual is not None and residual.dim() == 2:
        residual = residual.unsqueeze(0)

    for layer in self.layers[self.start_layer:self.end_layer]:
        hidden_states, residual = layer(
            positions,
            hidden_states,
            residual,
        )

    if not get_pp_group().is_last_rank:
        return IntermediateTensors({
            "hidden_states":
            hidden_states.squeeze(0),
            "residual":
            residual.squeeze(0) if residual is not None else None
        })

    hidden_states, _ = self.norm(hidden_states, residual)
    # NOTE: convert back to 2D; a single-token input would otherwise be
    # squeezed all the way down to 1D.
    hidden_states = hidden_states.squeeze()
    if hidden_states.dim() == 1:
        hidden_states = hidden_states.unsqueeze(0)
    return hidden_states
def Qwen2Attention_forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """Replacement for ``Qwen2Attention.forward`` on Biren SUPA devices.

    Projects ``hidden_states`` to a fused QKV tensor and splits it into
    q, k and v. It applies rotary embedding to q and k, runs attention, and
    returns the output projection.

    When ``envs.VLLM_BR_DEVICE_SPC_NUM > 16`` the split goes through
    ``torch_br.split_w_sbp_infer``; otherwise a plain ``Tensor.split`` on
    the last dimension is used.
    """
    qkv, _ = self.qkv_proj(hidden_states)
    split_sizes = [self.q_size, self.kv_size, self.kv_size]
    use_sbp_split = envs.VLLM_BR_DEVICE_SPC_NUM > 16
    if use_sbp_split:
        q, k, v = torch_br.split_w_sbp_infer(qkv, split_sizes)
    else:
        q, k, v = qkv.split(split_sizes, dim=-1)
    q, k = self.rotary_emb(positions, q, k)
    projected, _ = self.o_proj(self.attn(q, k, v))
    return projected
# Install the SUPA-specific implementations on the upstream vLLM Qwen2 classes.
vllm.model_executor.models.qwen2.Qwen2DecoderLayer.__init__ = Qwen2DecoderLayer__init__
# The patched __init__ builds MergedGateUpMLPSiluL2 as the MLP; the log names it.
logger.debug('[Patch] patch Qwen2 MLP with MergedGateUpMLPSiluL2')
Qwen2Model.load_weights = load_weights
Qwen2Model.forward = model_forward
Qwen2Attention.forward = Qwen2Attention_forward