################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from collections.abc import Iterable
from typing import Optional

import torch
import torch_br
from fastcore.basics import patch_to
from transformers import Qwen3Config

import vllm.model_executor.models.qwen3
from vllm.attention import AttentionType
from vllm.config import CacheConfig
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.qwen3 import (Qwen3Attention,
                                              Qwen3DecoderLayer, Qwen3Model)
from vllm.model_executor.models.utils import is_pp_missing_parameter
from vllm_br.v1.attention.backends.attention_v1 import (
    SUPAFlashAttentionMetadata)

from .qwen2 import model_forward
from .supa_module import MergedGateUpMLPSiluL2

# Inputs whose ``hidden_states.shape[-2]`` is at most this value are handed
# to the ``*_prefix_attn_infer`` kernels together with the raw QKV weight;
# longer inputs go through ``self.qkv_proj`` first.
_FUSED_QKV_MAX_SEQ_LEN = 512


def _check_mrope_section(rotary_emb: MRotaryEmbedding) -> None:
    """Assert that ``rotary_emb.mrope_section`` has a supported layout.

    The section list must have exactly three entries and the second and third
    entries must be equal, because only ``mrope_section[1]`` is forwarded to
    the torch_br kernels.

    Raises:
        AssertionError: if the layout is not supported.
    """
    section = rotary_emb.mrope_section
    assert len(section) == 3 and section[1] == section[2], (
        "current only support mrope_section width and height are equal!")


def _qkv_weight_and_scales(
        qkv_proj: torch.nn.Module
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Return the ``(weight, scales)`` pair to pass to the fused QKV kernels.

    The layer is probed for ``qweight``/``scales`` first, then for
    ``weight_packed``/``weight_scale``; if neither pair exists the plain
    ``weight`` is returned with ``None`` for the scales.
    """
    if hasattr(qkv_proj, "qweight"):
        return qkv_proj.qweight.data, qkv_proj.scales.data
    if hasattr(qkv_proj, "weight_packed"):
        return qkv_proj.weight_packed.data, qkv_proj.weight_scale.data
    return qkv_proj.weight, None


@patch_to(vllm.model_executor.models.qwen3.Qwen3Attention)
def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """Replacement ``Qwen3Attention.forward`` built on torch_br kernels.

    Args:
        positions: position tensor forwarded unchanged to the torch_br
            kernel.
        hidden_states: layer input; ``hidden_states.shape[-2]`` is used as
            the sequence length for kernel selection.

    Returns:
        The output of ``self.o_proj`` applied to the attention result, or
        ``hidden_states`` unchanged when there is no attention metadata
        (dummy run) or this layer's KV cache is ``None``.

    Raises:
        AssertionError: if ``self.rotary_emb`` is an ``MRotaryEmbedding``
            whose ``mrope_section`` is not of length three with equal second
            and third entries.
    """
    forward_context: ForwardContext = get_forward_context()
    attn_metadata: SUPAFlashAttentionMetadata = forward_context.attn_metadata
    if attn_metadata is None:
        # Dummy run: there is nothing to attend over.
        return hidden_states

    seq_len = hidden_states.shape[-2]
    if isinstance(attn_metadata, dict):
        # Metadata may be keyed per layer; pick this layer's entry.
        attn_metadata = attn_metadata[self.attn.layer_name]
    kv_cache = self.attn.kv_cache[forward_context.virtual_engine]
    if kv_cache is None:
        return hidden_states

    use_mrope = isinstance(self.rotary_emb, MRotaryEmbedding)
    if use_mrope:
        _check_mrope_section(self.rotary_emb)
    split_sizes = [self.q_size, self.kv_size, self.kv_size]

    if seq_len <= _FUSED_QKV_MAX_SEQ_LEN:
        # Short input: the kernel receives the raw projection weight (and
        # quantization scales, if any) instead of a precomputed QKV tensor.
        qkv_weight, qkv_scales = _qkv_weight_and_scales(self.qkv_proj)
        if use_mrope:
            q, k, v = torch_br.br_qwen3_vl_prefix_attn_infer(
                hidden_states,
                qkv_weight,
                split_sizes,
                self.head_dim,
                self.q_norm.variance_epsilon,
                self.q_norm.weight,
                self.k_norm.weight,
                self.rotary_emb.cos_sin_cache,
                kv_cache,
                positions,
                attn_metadata.slot_mapping,
                self.rotary_emb.mrope_section[1],
                bias=self.qkv_proj.bias,
                scales=qkv_scales)
        else:
            q, k, v = torch_br.br_qwen3_prefix_attn_infer(
                hidden_states,
                qkv_weight,
                split_sizes,
                self.head_dim,
                self.q_norm.variance_epsilon,
                self.q_norm.weight,
                self.k_norm.weight,
                self.rotary_emb.sin_cache,
                self.rotary_emb.cos_cache,
                kv_cache,
                positions,
                attn_metadata.slot_mapping,
                bias=self.qkv_proj.bias,
                scales=qkv_scales)
    else:
        # Long input: project first, then hand the packed QKV tensor to the
        # fused kernel along with the paging metadata.
        qkv, _ = self.qkv_proj(hidden_states)
        if use_mrope:
            q, k, v = torch_br.br_fused_rms_mrope_kvstore_infer(
                qkv,
                split_sizes,
                self.head_dim,
                self.q_norm.variance_epsilon,
                self.q_norm.weight,
                self.k_norm.weight,
                self.rotary_emb.cos_sin_cache,
                kv_cache,
                positions,
                attn_metadata.slot_mapping,
                attn_metadata.block_table,
                attn_metadata.query_start_loc,
                attn_metadata.context_lens,
                self.rotary_emb.mrope_section[1])
        else:
            q, k, v = torch_br.br_fused_rms_rope_kvstore_infer(
                qkv,
                split_sizes,
                self.head_dim,
                self.q_norm.variance_epsilon,
                self.q_norm.weight,
                self.k_norm.weight,
                self.rotary_emb.sin_cache,
                self.rotary_emb.cos_cache,
                kv_cache,
                positions,
                attn_metadata.slot_mapping,
                attn_metadata.block_table,
                attn_metadata.query_start_loc,
                attn_metadata.context_lens)

    # NOTE(review): every kernel above is given kv_cache and slot_mapping, so
    # presumably K/V are already stored and the attention backend must not
    # store them again -- confirm against the SUPA attention backend.
    if hasattr(attn_metadata, 'do_cache'):
        attn_metadata.do_cache = False
    attn_output = self.attn(q, k, v)
    output, _ = self.o_proj(attn_output)
    return output


def Qwen3DecoderLayer__init__(
    self,
    config: Qwen3Config,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None:
    """Replacement ``Qwen3DecoderLayer.__init__`` using a merged gate/up MLP.

    Builds ``self_attn`` as a ``Qwen3Attention``, ``mlp`` as a
    ``MergedGateUpMLPSiluL2`` and the two ``RMSNorm`` layers.

    Args:
        config: HF Qwen3 configuration supplying sizes, head counts, the
            RMSNorm epsilon and the optional ``rope_theta``, ``rope_scaling``,
            ``is_causal``, ``attention_bias`` and ``head_dim`` fields.
        cache_config: forwarded to ``Qwen3Attention``.
        quant_config: forwarded to ``Qwen3Attention`` and the MLP.
        prefix: parameter-name prefix of this layer; ``.self_attn`` and
            ``.mlp`` are appended for the sub-modules.
    """
    # This function lives outside the class body, so the zero-argument
    # ``super()`` form is unavailable; the explicit form is required.
    super(Qwen3DecoderLayer, self).__init__()
    self.hidden_size = config.hidden_size
    # Requires transformers > 4.32.0
    rope_theta = getattr(config, "rope_theta", 1000000)
    rope_scaling = getattr(config, "rope_scaling", None)

    # By default, Qwen3 uses causal attention as it is a decoder-only model.
    # You can override the HF config with `is_causal=False` to enable
    # bidirectional attention, which is used in some embedding models
    # (e.g. Alibaba-NLP/gte-Qwen3-7B-instruct)
    if getattr(config, "is_causal", True):
        attn_type = AttentionType.DECODER
    else:
        attn_type = AttentionType.ENCODER_ONLY

    self.self_attn = Qwen3Attention(
        hidden_size=self.hidden_size,
        num_heads=config.num_attention_heads,
        max_position=config.max_position_embeddings,
        num_kv_heads=config.num_key_value_heads,
        rope_theta=rope_theta,
        rms_norm_eps=config.rms_norm_eps,
        qkv_bias=getattr(config, 'attention_bias', False),
        head_dim=getattr(config, 'head_dim', None),
        cache_config=cache_config,
        quant_config=quant_config,
        rope_scaling=rope_scaling,
        prefix=f"{prefix}.self_attn",
        attn_type=attn_type,
    )
    self.mlp = MergedGateUpMLPSiluL2(
        hidden_size=self.hidden_size,
        intermediate_size=config.intermediate_size,
        hidden_act=config.hidden_act,
        quant_config=quant_config,
        prefix=f"{prefix}.mlp",
    )
    self.input_layernorm = RMSNorm(config.hidden_size,
                                   eps=config.rms_norm_eps)
    self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                            eps=config.rms_norm_eps)


def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Load checkpoint tensors into this model's parameters.

    Checkpoint names containing ``q_proj``/``k_proj``/``v_proj`` are mapped
    onto the merged ``qkv_proj`` parameter, and ``gate_proj``/``up_proj`` onto
    ``gate_up_proj``, using each parameter's shard-aware ``weight_loader``.
    Parameters whose name contains ``norm.weight`` are converted to float32
    after loading.

    Args:
        weights: iterable of ``(checkpoint_name, tensor)`` pairs.

    Returns:
        The set of parameter names that were loaded.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    params_dict = dict(self.named_parameters(remove_duplicate=False))
    loaded_params: set[str] = set()
    for name, loaded_weight in weights:
        if "rotary_emb.inv_freq" in name:
            continue
        if (self.quant_config is not None and
            (scale_name := self.quant_config.get_cache_scale(name))):
            # Loading kv cache quantization scales
            param = params_dict[scale_name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            # Non-scalar scale tensors contribute only their first element.
            loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                             loaded_weight[0])
            weight_loader(param, loaded_weight)
            loaded_params.add(scale_name)
            continue
        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # Reached only when no stacked mapping consumed the weight.
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # Remapping the name of FP8 kv-scale.
            name = maybe_remap_kv_scale_name(name, params_dict)
            if name is None:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            # Norm weights are kept in float32 regardless of the dtype they
            # were loaded with.
            if "norm.weight" in name:
                param.data = param.data.to(torch.float32)
        loaded_params.add(name)
    return loaded_params


# Install the patches on the upstream vLLM Qwen3 classes.
vllm.model_executor.models.qwen3.Qwen3DecoderLayer.__init__ = (
    Qwen3DecoderLayer__init__)
logger.debug('[Patch] patch Qwen3 MLP with MergedGateUpMLPSiluL2')
Qwen3Model.load_weights = load_weights
Qwen3Model.forward = model_forward