"""Inference-only Qwen3 model compatible with HuggingFace weights."""
from collections.abc import Iterable
from typing import Optional, Union, Any, Dict
import torch
from torch import nn
from vllm.logger import init_logger
from .vars import *
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
logger = init_logger(__name__)
# Unify the parameter names coming from the different quantization methods.
def set_fused_params(fused_params: Dict[str, Any], quant_method: QuantizeMethodBase, layer: nn.Module, name: str):
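    """Expose ``layer``'s weights in ``fused_params`` under the unified keys
    ``<name>_weight``, ``<name>_weight_scale``, ``<name>_bias`` and ``<name>_qzeros``,
    regardless of which quantization method (unquantized, FP8, GPTQ or AWQ) produced them."""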
    if isinstance(quant_method, UnquantizedLinearMethod):
        fused_params[name + '_weight'] = layer.weight
        fused_params[name + '_weight_scale'] = torch.Tensor()
        fused_params[name + '_bias'] = None
        fused_params[name + '_qzeros'] = None
    elif isinstance(quant_method, Fp8LinearMethod):
        fused_params[name + '_weight'] = layer.weight
        fused_params[name + '_weight_scale'] = layer.weight_scale_inv
        fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias
        fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros
    elif isinstance(quant_method, GPTQLinearMethod):
        fused_params[name + '_weight'] = layer.qweight
        fused_params[name + '_weight_scale'] = layer.scales
        fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias
        fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros
    elif isinstance(quant_method, AWQLinearMethod):
        fused_params[name + '_weight'] = layer.qweight
        fused_params[name + '_weight_scale'] = layer.scales
        fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias
        fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros
    else:
        raise ValueError(f"Unsupported quant_method: {quant_method}")
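# Illustration only (not executed): for a GPTQ-quantized projection,
#     set_fused_params(fused_params, layer.quant_method, layer, 'qkv_proj')
# fills 'qkv_proj_weight' (= layer.qweight), 'qkv_proj_weight_scale' (= layer.scales),
# 'qkv_proj_bias' and 'qkv_proj_qzeros', so the fused kernel below can read the same
# keys for every quantization scheme.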
class Qwen3Attention(nn.Module):
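    """Qwen3 attention layer.

    forward() dispatches between the VACC fused-attention kernel
    (USE_FUSED_QWEN_ATTENTION) and the original vLLM attention path.
    """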
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor] = None  # newly added parameter for the fused path
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
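        """Attention forward pass.

        On the fused VACC path this returns ``(attn_out, residual_out)``;
        on the original path it returns a single output tensor.
        """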
        forward_context: ForwardContext = get_forward_context()
        attn_metadata_all = forward_context.attn_metadata
        kv_cache = self.attn.kv_cache[forward_context.virtual_engine]
        # Reshape the KV cache to (K/V, num_blocks, 16 tokens per block, num_kv_heads, head_dim).
        num_kv_heads = max(1, self.total_num_kv_heads // get_tp_group().world_size)
        kv_cache = kv_cache.view(2, -1, 16, num_kv_heads, self.head_dim)
        if isinstance(attn_metadata_all, dict):
            attn_metadata = next(iter(attn_metadata_all.values()))
            is_decode = attn_metadata.prefill_metadata is None
        else:
            is_decode = attn_metadata_all.prefill_metadata is None
            attn_metadata = attn_metadata_all
        reduce_result = is_decode
        # total_bytes = hidden_states.numel() * hidden_states.element_size() * get_tp_group().world_size
        # # only support 4M now
        # if total_bytes < 4194304:
        #     reduce_result = True
        if USE_FUSED_QWEN_ATTENTION:
            if is_decode:
                positions = [i - 1 for i in attn_metadata.seq_lens]
                cos_cache = [self.rotary_emb.cos_cache[i:i+1, ...] for i in positions]
                sin_cache = [self.rotary_emb.sin_cache[i:i+1, ...] for i in positions]
            else:
                cos_cache = [self.rotary_emb.cos_cache[:i, ...] for i in attn_metadata.seq_lens]
                sin_cache = [self.rotary_emb.sin_cache[:i, ...] for i in attn_metadata.seq_lens]
            if residual is None:
                res_out = hidden_states
            # from torch_vacc.vacc import fuse_atten_qwen3
            attn_outs = None
            if not is_decode:
                from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
                if memory_recycler is not None:
                    attn_outs = memory_recycler.MLA_OPROJ_OUT_BUFFER
            total_num_kv_heads = self.total_num_kv_heads
            if self.total_num_kv_heads < get_tp_group().world_size:
                assert get_tp_group().world_size % self.total_num_kv_heads == 0
                total_num_kv_heads = get_tp_group().world_size
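            # If the TP world size exceeds the model's KV-head count, report
            # world_size to the kernel so each rank accounts for at least one
            # KV head; num_kv_heads above was already clamped to a minimum of 1.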
            attn_outs = torch.vacc.fuse_atten_qwen3(
                # attn_outs = vacc_fused_attn_qwen3_naive(
                hidden_states=hidden_states,
                residual=residual,
                hidden_states_norm_weight=self.fused_params['input_layernorm_weight'],
                qkv_proj_weight=self.fused_params['qkv_proj_weight'],
                qkv_proj_weight_scale=self.fused_params['qkv_proj_weight_scale'],
                qkv_proj_bias=self.fused_params['qkv_proj_bias'],
                qkv_proj_qzeros=self.fused_params['qkv_proj_qzeros'],
                q_layernorm_weight=self.fused_params['q_norm_weight'],
                k_layernorm_weight=self.fused_params['k_norm_weight'],
                sin_cache=sin_cache,
                cos_cache=cos_cache,
                slot_mapping=attn_metadata.slot_mapping,
                kv_cache=kv_cache,
                block_tables=attn_metadata.block_tables,  # tensor
                block_group_size=env_blk_grp_size,
                o_proj_weight=self.fused_params['o_proj_weight'],
                o_proj_weight_scale=self.fused_params['o_proj_weight_scale'],
                o_proj_bias=self.fused_params['o_proj_bias'],
                o_proj_qzeros=self.fused_params['o_proj_qzeros'],
                seq_lens=attn_metadata.seq_lens,
                sm_scale=self.scaling,
                num_attention_heads=self.total_num_heads,
                num_key_value_heads=total_num_kv_heads,
                flash_attention=is_decode,  # decode uses flash attention by default
                is_decode=is_decode,
                reduce_result=reduce_result,
                world_size=get_tp_group().world_size,
                rank=get_tp_group().rank_in_group,
                group_id=get_tp_group().group_id,
                dev_info=get_tp_group().rank_device_infos,
                output_opt=attn_outs,
                res_opt=residual)
            if residual is None:
                attn_out = tensor_model_parallel_all_reduce(attn_outs) if not reduce_result else attn_outs
            else:
                res_out = attn_outs[1]
                attn_out = tensor_model_parallel_all_reduce(attn_outs[0]) if not reduce_result else attn_outs[0]
            return attn_out, res_out
        else:
            # Original (non-fused) attention path.
            qkv, _ = self.qkv_proj(hidden_states)
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
            # Add qk-norm
            q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
                               self.head_dim)
            q_by_head = self.q_norm(q_by_head)
            q = q_by_head.view(q.shape)
            k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
                               self.head_dim)
            k_by_head = self.k_norm(k_by_head)
            k = k_by_head.view(k.shape)
            q, k = self.rotary_emb(positions, q, k)
            attn_output = self.attn(q, k, v)
            output, _ = self.o_proj(attn_output)
            return output
class Qwen3DecoderLayer(nn.Module):
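    """Qwen3 decoder layer: self-attention followed by the MLP block.

    On the fused path, the input layernorm and the Q/K norms are folded into
    the fused attention kernel via ``fused_params``.
    """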
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
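        """Run one decoder layer and return ``(hidden_states, residual)``."""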
        # Self Attention
        # NOTE: input_layernorm is fused in vacc_fused_attn_qwen3
        if USE_FUSED_QWEN_ATTENTION:
            if not hasattr(self.self_attn, "fused_params"):
                self.self_attn.fused_params = {}
                self.self_attn.fused_params['input_layernorm_weight'] = self.input_layernorm.weight
                self.self_attn.fused_params['q_norm_weight'] = self.self_attn.q_norm.weight
                self.self_attn.fused_params['k_norm_weight'] = self.self_attn.k_norm.weight
                set_fused_params(self.self_attn.fused_params, self.self_attn.qkv_proj.quant_method, self.self_attn.qkv_proj, 'qkv_proj')
                set_fused_params(self.self_attn.fused_params, self.self_attn.o_proj.quant_method, self.self_attn.o_proj, 'o_proj')
            hidden_states, residual = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual)
        else:
            if residual is None:
                residual = hidden_states
                hidden_states = self.input_layernorm(hidden_states)
            else:
                hidden_states, residual = self.input_layernorm(
                    hidden_states, residual)
            hidden_states = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
            )
        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual