194 lines
9.5 KiB
Python
194 lines
9.5 KiB
Python
"""Inference-only Qwen3 model compatible with HuggingFace weights."""
|
|
from collections.abc import Iterable
|
|
from typing import Optional, Union, Any, Dict
|
|
import torch
|
|
from torch import nn
|
|
from vllm.logger import init_logger
|
|
from .vars import *
|
|
from vllm.model_executor.layers.linear import UnquantizedLinearMethod as UnquantizedLinearMethod
|
|
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
|
|
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
|
|
from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
|
|
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
|
|
from vllm.forward_context import ForwardContext, get_forward_context
|
|
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
|
|
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
|
|
|
|
# Module-level logger for this file (not referenced by the code below at present).
logger = init_logger(__name__)
|
|
|
|
def set_fused_params(fused_params: Dict[str, Any], quant_method: QuantizeMethodBase, layer: nn.Module, name: str):
    """Store ``layer``'s linear tensors in ``fused_params`` under uniform keys.

    Different quantization methods keep their tensors under different
    attribute names. This writes them into four keys, ``<name>_weight``,
    ``<name>_weight_scale``, ``<name>_bias`` and ``<name>_qzeros``, so the
    fused kernel call can read them without caring about the method.

    Args:
        fused_params: Dict to update in place.
        quant_method: The layer's quantization method object; it selects which
            attributes are read.
        layer: The linear layer holding the tensors.
        name: Key prefix, e.g. ``'qkv_proj'`` or ``'o_proj'``.

    Raises:
        ValueError: if ``quant_method`` is not Unquantized, FP8, GPTQ or AWQ.
    """
    if isinstance(quant_method, UnquantizedLinearMethod):
        # Plain weights: an empty tensor stands in for "no scale"; bias and
        # qzeros are always passed as None on this path.
        fused_params[name + '_weight'] = layer.weight
        fused_params[name + '_weight_scale'] = torch.Tensor()
        fused_params[name + '_bias'] = None
        fused_params[name + '_qzeros'] = None
    elif isinstance(quant_method, Fp8LinearMethod):
        # FP8 keeps the float8 weight in `weight` and its scale in `weight_scale_inv`.
        fused_params[name + '_weight'] = layer.weight
        fused_params[name + '_weight_scale'] = layer.weight_scale_inv
        fused_params[name + '_bias'] = getattr(layer, 'bias', None)
        fused_params[name + '_qzeros'] = getattr(layer, 'qzeros', None)
    elif isinstance(quant_method, (GPTQLinearMethod, AWQLinearMethod)):
        # GPTQ and AWQ share the same packed-weight attribute layout.
        fused_params[name + '_weight'] = layer.qweight
        fused_params[name + '_weight_scale'] = layer.scales
        fused_params[name + '_bias'] = getattr(layer, 'bias', None)
        fused_params[name + '_qzeros'] = getattr(layer, 'qzeros', None)
    else:
        raise ValueError(f"Unsupported quant_method: {quant_method}")
|
|
|
|
class Qwen3Attention(nn.Module):
    """Qwen3 self-attention with an optional fused ``torch.vacc`` fast path.

    Only ``forward`` is defined here. Everything it reads from ``self``
    (``attn``, ``qkv_proj``, ``o_proj``, ``q_norm``, ``k_norm``,
    ``rotary_emb``, ``head_dim``, ``q_size``, ``kv_size``, ``scaling``,
    ``total_num_heads``, ``total_num_kv_heads``) must be set up elsewhere.
    Presumably this comes from the upstream vLLM ``Qwen3Attention`` that this
    ``forward`` is grafted onto (TODO confirm). For that reason, keep all logic
    inside ``forward`` rather than adding new methods.

    On the fused path ``self.fused_params`` must also exist;
    ``Qwen3DecoderLayer.forward`` fills it lazily on its first call.
    """

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor] = None,  # new added params
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Compute attention for one decoder layer.

        Args:
            positions: Token positions for the rotary embedding. Only used on
                the non-fused path; the fused path derives positions from
                ``attn_metadata.seq_lens`` instead.
            hidden_states: Input activations of this layer. On the fused path
                these are the un-normalized inputs, because the kernel is also
                given the input-layernorm weight.
            residual: Residual stream from the previous layer, or ``None``.
                Only the fused path uses it.

        Returns:
            Fused path: ``(attn_out, res_out)``, where ``res_out`` is the
            residual to carry into the post-attention layernorm.
            Non-fused path: the ``o_proj`` output tensor only.

        Raises:
            ValueError: on the fused path, if there are fewer KV heads than
                tensor-parallel ranks and the rank count is not a multiple of
                the KV-head count.
        """
        if not USE_FUSED_QWEN_ATTENTION:
            # Reference path: same computation as upstream vLLM. It needs none
            # of the fused-path bookkeeping below (KV-cache reshape, metadata).
            qkv, _ = self.qkv_proj(hidden_states)
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
            # Qwen3 qk-norm: normalize q and k per head, then flatten back.
            q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
                               self.head_dim)
            q = self.q_norm(q_by_head).view(q.shape)
            k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
                               self.head_dim)
            k = self.k_norm(k_by_head).view(k.shape)
            q, k = self.rotary_emb(positions, q, k)
            attn_output = self.attn(q, k, v)
            output, _ = self.o_proj(attn_output)
            return output

        # ---- Fused path ----
        forward_context: ForwardContext = get_forward_context()
        attn_metadata_all = forward_context.attn_metadata
        tp_group = get_tp_group()
        world_size = tp_group.world_size

        # Reshape the raw KV cache for the kernel. The layout is presumably
        # (k/v, num_blocks, block_size, kv_heads_per_rank, head_dim).
        # NOTE(review): block size 16 is hard-coded; confirm it matches the
        # configured cache block size.
        num_kv_heads = max(1, self.total_num_kv_heads // world_size)
        kv_cache = self.attn.kv_cache[forward_context.virtual_engine].view(
            2, -1, 16, num_kv_heads, self.head_dim)

        # attn_metadata may be a dict (presumably keyed per layer); only its
        # first value is used.
        if isinstance(attn_metadata_all, dict):
            attn_metadata = next(iter(attn_metadata_all.values()))
        else:
            attn_metadata = attn_metadata_all
        is_decode = attn_metadata.prefill_metadata is None

        # When True the kernel output is used as-is; otherwise it is
        # all-reduced below. Only decode reduces in-kernel. An earlier,
        # disabled variant noted the in-kernel reduce supports at most 4 MiB.
        reduce_result = is_decode

        # Per-sequence rotary cos/sin slices.
        rope = self.rotary_emb
        if is_decode:
            # Assumes one new token per sequence, at position seq_len - 1.
            last_positions = [seq_len - 1 for seq_len in attn_metadata.seq_lens]
            cos_cache = [rope.cos_cache[p:p + 1, ...] for p in last_positions]
            sin_cache = [rope.sin_cache[p:p + 1, ...] for p in last_positions]
        else:
            # Positions 0 .. seq_len - 1 for each sequence.
            # NOTE(review): assumes no prefix caching / chunked prefill here.
            cos_cache = [rope.cos_cache[:seq_len, ...] for seq_len in attn_metadata.seq_lens]
            sin_cache = [rope.sin_cache[:seq_len, ...] for seq_len in attn_metadata.seq_lens]

        # Without an incoming residual, the un-normalized input becomes the residual.
        if residual is None:
            res_out = hidden_states

        # Prefill may write into a pre-allocated, recycled output buffer.
        output_buffer = None
        if not is_decode:
            from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
            if memory_recycler is not None:
                output_buffer = memory_recycler.MLA_OPROJ_OUT_BUFFER

        # With fewer KV heads than ranks, heads are (presumably) replicated so
        # each rank holds one, and the kernel is told the replicated total.
        total_num_kv_heads = self.total_num_kv_heads
        if total_num_kv_heads < world_size:
            if world_size % total_num_kv_heads != 0:
                raise ValueError(
                    f"tensor-parallel world size ({world_size}) must be a "
                    f"multiple of the number of KV heads ({total_num_kv_heads})")
            total_num_kv_heads = world_size

        fused_out = torch.vacc.fuse_atten_qwen3(
            hidden_states=hidden_states,
            residual=residual,
            hidden_states_norm_weight=self.fused_params['input_layernorm_weight'],
            qkv_proj_weight=self.fused_params['qkv_proj_weight'],
            qkv_proj_weight_scale=self.fused_params['qkv_proj_weight_scale'],
            qkv_proj_bias=self.fused_params['qkv_proj_bias'],
            qkv_proj_qzeros=self.fused_params['qkv_proj_qzeros'],
            q_layernorm_weight=self.fused_params['q_norm_weight'],
            k_layernorm_weight=self.fused_params['k_norm_weight'],
            sin_cache=sin_cache,
            cos_cache=cos_cache,
            slot_mapping=attn_metadata.slot_mapping,
            kv_cache=kv_cache,
            block_tables=attn_metadata.block_tables,  # tensor
            block_group_size=env_blk_grp_size,
            o_proj_weight=self.fused_params['o_proj_weight'],
            o_proj_weight_scale=self.fused_params['o_proj_weight_scale'],
            o_proj_bias=self.fused_params['o_proj_bias'],
            o_proj_qzeros=self.fused_params['o_proj_qzeros'],
            seq_lens=attn_metadata.seq_lens,
            sm_scale=self.scaling,
            num_attention_heads=self.total_num_heads,
            num_key_value_heads=total_num_kv_heads,
            flash_attention=is_decode,  # decode uses flash attention by default
            is_decode=is_decode,
            reduce_result=reduce_result,
            world_size=world_size,
            rank=tp_group.rank_in_group,
            group_id=tp_group.group_id,
            dev_info=tp_group.rank_device_infos,
            output_opt=output_buffer,
            res_opt=residual)

        if residual is None:
            # The kernel returned only the attention output.
            attn_out = fused_out
        else:
            # The kernel returned (attention output, updated residual).
            attn_out = fused_out[0]
            res_out = fused_out[1]
        if not reduce_result:
            attn_out = tensor_model_parallel_all_reduce(attn_out)
        return attn_out, res_out
|
|
|
|
|
|
class Qwen3DecoderLayer(nn.Module):
    """Qwen3 decoder layer that can route attention through the fused kernel.

    Only ``forward`` is defined here. The submodules it uses
    (``self_attn``, ``input_layernorm``, ``post_attention_layernorm``,
    ``mlp``) are expected to be set up elsewhere (TODO confirm where).
    """

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run attention plus MLP for one layer.

        Args:
            positions: Token positions, forwarded to the attention module.
            hidden_states: Layer input activations.
            residual: Residual stream from the previous layer, or ``None``.

        Returns:
            ``(hidden_states, residual)``: the MLP output and the residual to
            pass to the next layer.
        """
        attn = self.self_attn

        if not USE_FUSED_QWEN_ATTENTION:
            # Unfused path: apply the input layernorm here, then plain attention.
            if residual is not None:
                hidden_states, residual = self.input_layernorm(
                    hidden_states, residual)
            else:
                residual = hidden_states
                hidden_states = self.input_layernorm(hidden_states)
            hidden_states = attn(
                positions=positions,
                hidden_states=hidden_states,
            )
        else:
            # Fused path: the kernel also applies input_layernorm, so it
            # needs that weight. All kernel tensors are gathered into
            # `attn.fused_params` once, on the first call, and reused afterwards.
            if not hasattr(attn, "fused_params"):
                params = attn.fused_params = {}
                params['input_layernorm_weight'] = self.input_layernorm.weight
                params['q_norm_weight'] = attn.q_norm.weight
                params['k_norm_weight'] = attn.k_norm.weight
                set_fused_params(params, attn.qkv_proj.quant_method, attn.qkv_proj, 'qkv_proj')
                set_fused_params(params, attn.o_proj.quant_method, attn.o_proj, 'o_proj')

            hidden_states, residual = attn(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual)

        # Fully connected block.
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        return self.mlp(hidden_states), residual