"""Inference-only Qwen3 model compatible with HuggingFace weights.""" from collections.abc import Iterable from typing import Optional, Union, Any, Dict import torch from torch import nn from vllm.logger import init_logger from .vars import * from vllm.model_executor.layers.linear import UnquantizedLinearMethod as UnquantizedLinearMethod from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod from vllm.model_executor.layers.quantization.awq import AWQLinearMethod from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase from vllm.forward_context import ForwardContext, get_forward_context from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size logger = init_logger(__name__) # uniform the params names from different quantize method def set_fused_params(fused_params: Dict[str, Any], quant_method: QuantizeMethodBase, layer: nn.Module, name: str): if isinstance(quant_method, UnquantizedLinearMethod): fused_params[name + '_weight'] = layer.weight fused_params[name + '_weight_scale'] = torch.Tensor() fused_params[name + '_bias'] = None fused_params[name + '_qzeros'] = None elif isinstance(quant_method, Fp8LinearMethod): fused_params[name + '_weight'] = layer.weight fused_params[name + '_weight_scale'] = layer.weight_scale_inv fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros elif isinstance(quant_method, GPTQLinearMethod): fused_params[name + '_weight'] = layer.qweight fused_params[name + '_weight_scale'] = layer.scales fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros elif isinstance(quant_method, AWQLinearMethod): fused_params[name + '_weight'] = layer.qweight fused_params[name + '_weight_scale'] = layer.scales fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros else: raise ValueError(f"Unsupported quant_method: {quant_method}") class Qwen3Attention(nn.Module): def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor] = None # new added params ) -> torch.Tensor: forward_context: ForwardContext = get_forward_context() attn_metadata_all = forward_context.attn_metadata kv_cache = self.attn.kv_cache[forward_context.virtual_engine] # reshape kvcache num_kv_heads = max(1, self.total_num_kv_heads // get_tp_group().world_size) kv_cache = kv_cache.view(2, -1, 16, num_kv_heads, self.head_dim) if isinstance(attn_metadata_all, dict): attn_metadata = attn_metadata_all.items().__iter__().__next__()[1] is_decode = attn_metadata.prefill_metadata is None else: is_decode = attn_metadata_all.prefill_metadata is None attn_metadata = attn_metadata_all reduce_result = is_decode # total_bytes = hidden_states.numel() * hidden_states.element_size() * get_tp_group().world_size # # only support 4M now # if total_bytes < 4194304: # reduce_result = True if USE_FUSED_QWEN_ATTENTION: if is_decode: positions = [i - 1 for i in attn_metadata.seq_lens] cos_cache = [self.rotary_emb.cos_cache[i:i+1, ...] for i in positions] sin_cache = [self.rotary_emb.sin_cache[i:i+1, ...] 
        if USE_FUSED_QWEN_ATTENTION:
            if is_decode:
                # For decode, each sequence only attends from its last position.
                positions = [i - 1 for i in attn_metadata.seq_lens]
                cos_cache = [self.rotary_emb.cos_cache[i:i+1, ...] for i in positions]
                sin_cache = [self.rotary_emb.sin_cache[i:i+1, ...] for i in positions]
            else:
                cos_cache = [self.rotary_emb.cos_cache[:i, ...] for i in attn_metadata.seq_lens]
                sin_cache = [self.rotary_emb.sin_cache[:i, ...] for i in attn_metadata.seq_lens]

            if residual is None:
                res_out = hidden_states

            # from torch_vacc.vacc import fuse_atten_qwen3
            attn_outs = None
            if not is_decode:
                from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
                if memory_recycler is not None:
                    attn_outs = memory_recycler.MLA_OPROJ_OUT_BUFFER

            total_num_kv_heads = self.total_num_kv_heads
            if self.total_num_kv_heads < get_tp_group().world_size:
                assert get_tp_group().world_size % self.total_num_kv_heads == 0
                total_num_kv_heads = get_tp_group().world_size

            attn_outs = torch.vacc.fuse_atten_qwen3(
                # attn_outs = vacc_fused_attn_qwen3_naive(
                hidden_states=hidden_states,
                residual=residual,
                hidden_states_norm_weight=self.fused_params['input_layernorm_weight'],
                qkv_proj_weight=self.fused_params['qkv_proj_weight'],
                qkv_proj_weight_scale=self.fused_params['qkv_proj_weight_scale'],
                qkv_proj_bias=self.fused_params['qkv_proj_bias'],
                qkv_proj_qzeros=self.fused_params['qkv_proj_qzeros'],
                q_layernorm_weight=self.fused_params['q_norm_weight'],
                k_layernorm_weight=self.fused_params['k_norm_weight'],
                sin_cache=sin_cache,
                cos_cache=cos_cache,
                slot_mapping=attn_metadata.slot_mapping,
                kv_cache=kv_cache,
                block_tables=attn_metadata.block_tables,  # tensor
                block_group_size=env_blk_grp_size,
                o_proj_weight=self.fused_params['o_proj_weight'],
                o_proj_weight_scale=self.fused_params['o_proj_weight_scale'],
                o_proj_bias=self.fused_params['o_proj_bias'],
                o_proj_qzeros=self.fused_params['o_proj_qzeros'],
                seq_lens=attn_metadata.seq_lens,
                sm_scale=self.scaling,
                num_attention_heads=self.total_num_heads,
                num_key_value_heads=total_num_kv_heads,
                flash_attention=is_decode,  # decode uses flash attention by default
                is_decode=is_decode,
                reduce_result=reduce_result,
                world_size=get_tp_group().world_size,
                rank=get_tp_group().rank_in_group,
                group_id=get_tp_group().group_id,
                dev_info=get_tp_group().rank_device_infos,
                output_opt=attn_outs,
                res_opt=residual)

            if residual is None:
                attn_out = tensor_model_parallel_all_reduce(attn_outs) if not reduce_result else attn_outs
            else:
                res_out = attn_outs[1]
                attn_out = tensor_model_parallel_all_reduce(attn_outs[0]) if not reduce_result else attn_outs[0]
            return attn_out, res_out
        else:
            # Original (unfused) path.
            qkv, _ = self.qkv_proj(hidden_states)
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
            # Add qk-norm
            q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
            q_by_head = self.q_norm(q_by_head)
            q = q_by_head.view(q.shape)
            k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
            k_by_head = self.k_norm(k_by_head)
            k = k_by_head.view(k.shape)
            q, k = self.rotary_emb(positions, q, k)
            attn_output = self.attn(q, k, v)
            output, _ = self.o_proj(attn_output)
            return output


class Qwen3DecoderLayer(nn.Module):

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        # NOTE: input_layernorm is fused into the attention op (torch.vacc.fuse_atten_qwen3).
        if USE_FUSED_QWEN_ATTENTION:
            if not hasattr(self.self_attn, "fused_params"):
                self.self_attn.fused_params = {}
                self.self_attn.fused_params['input_layernorm_weight'] = self.input_layernorm.weight
                self.self_attn.fused_params['q_norm_weight'] = self.self_attn.q_norm.weight
                self.self_attn.fused_params['k_norm_weight'] = self.self_attn.k_norm.weight
                set_fused_params(self.self_attn.fused_params,
                                 self.self_attn.qkv_proj.quant_method,
                                 self.self_attn.qkv_proj, 'qkv_proj')
                set_fused_params(self.self_attn.fused_params,
                                 self.self_attn.o_proj.quant_method,
                                 self.self_attn.o_proj, 'o_proj')
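            # The fused op reads the weights gathered into fused_params above, so the
            # call below only passes activations and returns (attn_output, residual).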
            hidden_states, residual = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual)
        else:
            if residual is None:
                residual = hidden_states
                hidden_states = self.input_layernorm(hidden_states)
            else:
                hidden_states, residual = self.input_layernorm(
                    hidden_states, residual)
            hidden_states = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
            )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual