init
This commit is contained in:
194
vllm_vacc/vllm/model_executor/models/qwen3.py
Normal file
194
vllm_vacc/vllm/model_executor/models/qwen3.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""Inference-only Qwen3 model compatible with HuggingFace weights."""
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional, Union, Any, Dict
|
||||
import torch
|
||||
from torch import nn
|
||||
from vllm.logger import init_logger
|
||||
from .vars import *
|
||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod as UnquantizedLinearMethod
|
||||
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
|
||||
from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
|
||||
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# uniform the params names from different quantize method
|
||||
def set_fused_params(fused_params: Dict[str, Any], quant_method: QuantizeMethodBase, layer: nn.Module, name: str):
    """Populate ``fused_params`` with uniformly named entries for ``layer``.

    The fused attention kernel expects the same four keys per projection
    regardless of quantization backend, so this normalizes the differing
    attribute names of each quant method into:
    ``{name}_weight``, ``{name}_weight_scale``, ``{name}_bias``,
    ``{name}_qzeros``.

    Args:
        fused_params: Dict mutated in place with the normalized entries.
        quant_method: The layer's quantization method; selects the mapping.
        layer: The linear layer whose parameters are read.
        name: Key prefix, e.g. ``'qkv_proj'`` or ``'o_proj'``.

    Raises:
        ValueError: If ``quant_method`` is not one of the supported types.
    """
    if isinstance(quant_method, UnquantizedLinearMethod):
        fused_params[name + '_weight'] = layer.weight
        # Unquantized weights carry no scale; the kernel takes an empty tensor.
        fused_params[name + '_weight_scale'] = torch.Tensor()
        fused_params[name + '_bias'] = None
        fused_params[name + '_qzeros'] = None
    elif isinstance(quant_method, Fp8LinearMethod):
        fused_params[name + '_weight'] = layer.weight
        fused_params[name + '_weight_scale'] = layer.weight_scale_inv
        fused_params[name + '_bias'] = getattr(layer, 'bias', None)
        fused_params[name + '_qzeros'] = getattr(layer, 'qzeros', None)
    elif isinstance(quant_method, (GPTQLinearMethod, AWQLinearMethod)):
        # GPTQ and AWQ layers expose identical attribute names
        # (qweight/scales/qzeros), so one branch covers both.
        fused_params[name + '_weight'] = layer.qweight
        fused_params[name + '_weight_scale'] = layer.scales
        fused_params[name + '_bias'] = getattr(layer, 'bias', None)
        fused_params[name + '_qzeros'] = getattr(layer, 'qzeros', None)
    else:
        raise ValueError(f"Unsupported quant_method: {quant_method}")
|
||||
|
||||
class Qwen3Attention(nn.Module):
    """Qwen3 self-attention with two execution paths: a VACC fused kernel
    (``torch.vacc.fuse_atten_qwen3``) that folds input layernorm, QKV
    projection, qk-norm, RoPE, attention, and output projection into one
    call, and the original unfused reference pipeline.
    """

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor] = None  # fused path only; new added param
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Run attention over ``hidden_states``.

        Returns:
            On the fused path (``USE_FUSED_QWEN_ATTENTION``): a tuple
            ``(attn_out, res_out)`` — attention output and the residual
            stream for the next layer.
            On the reference path: a single output tensor.
        """

        forward_context: ForwardContext = get_forward_context()
        attn_metadata_all = forward_context.attn_metadata
        kv_cache = self.attn.kv_cache[forward_context.virtual_engine]

        # reshape kvcache: (K/V plane, blocks, block_size, kv_heads, head_dim)
        # NOTE(review): block size appears hard-coded to 16 here — confirm it
        # matches the engine's cache_config.block_size.
        num_kv_heads = max(1, self.total_num_kv_heads // get_tp_group().world_size)
        kv_cache = kv_cache.view(2, -1, 16, num_kv_heads, self.head_dim)

        # attn_metadata may be a per-group dict; take the first group's
        # metadata (equivalent to next(iter(attn_metadata_all.values()))).
        if isinstance(attn_metadata_all, dict):
            attn_metadata = attn_metadata_all.items().__iter__().__next__()[1]
            is_decode = attn_metadata.prefill_metadata is None
        else:
            is_decode = attn_metadata_all.prefill_metadata is None
            attn_metadata = attn_metadata_all

        # When True, the fused kernel performs the TP reduction itself and no
        # host-side all-reduce is needed afterwards (decode only, currently).
        reduce_result = is_decode
        # total_bytes = hidden_states.numel() * hidden_states.element_size() * get_tp_group().world_size
        # # only support 4M now
        # if total_bytes < 4194304:
        #     reduce_result = True

        if USE_FUSED_QWEN_ATTENTION:
            # Gather per-sequence RoPE tables: decode needs only the last
            # position (seq_len - 1); prefill needs positions [0, seq_len).
            # NOTE: this rebinding shadows the `positions` parameter.
            if is_decode:
                positions = [i - 1 for i in attn_metadata.seq_lens]
                cos_cache = [self.rotary_emb.cos_cache[i:i+1, ...] for i in positions]
                sin_cache = [self.rotary_emb.sin_cache[i:i+1, ...] for i in positions]
            else:
                cos_cache = [self.rotary_emb.cos_cache[:i, ...] for i in attn_metadata.seq_lens]
                sin_cache = [self.rotary_emb.sin_cache[:i, ...] for i in attn_metadata.seq_lens]
            # First layer (no residual yet): the incoming hidden_states double
            # as the residual stream returned to the caller.
            if residual is None:
                res_out = hidden_states
            #from torch_vacc.vacc import fuse_atten_qwen3
            # Optional pre-allocated output buffer for prefill, reused from
            # the memory recycler when available (passed as output_opt below).
            attn_outs = None
            if not is_decode:
                from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
                if memory_recycler is not None:
                    attn_outs = memory_recycler.MLA_OPROJ_OUT_BUFFER

            # If TP world size exceeds the model's KV head count, the KV heads
            # are replicated so each rank sees at least one head.
            total_num_kv_heads = self.total_num_kv_heads
            if self.total_num_kv_heads < get_tp_group().world_size:
                assert get_tp_group().world_size % self.total_num_kv_heads == 0
                total_num_kv_heads = get_tp_group().world_size

            # Single fused call: layernorm + QKV proj + qk-norm + RoPE +
            # KV-cache write + attention + O proj (and optionally the TP
            # reduction, when reduce_result is True).
            attn_outs = torch.vacc.fuse_atten_qwen3(
            # attn_outs = vacc_fused_attn_qwen3_naive(
                hidden_states=hidden_states,
                residual=residual,
                hidden_states_norm_weight=self.fused_params['input_layernorm_weight'],
                qkv_proj_weight=self.fused_params['qkv_proj_weight'],
                qkv_proj_weight_scale=self.fused_params['qkv_proj_weight_scale'],
                qkv_proj_bias=self.fused_params['qkv_proj_bias'],
                qkv_proj_qzeros=self.fused_params['qkv_proj_qzeros'],
                q_layernorm_weight=self.fused_params['q_norm_weight'],
                k_layernorm_weight=self.fused_params['k_norm_weight'],
                sin_cache=sin_cache,
                cos_cache=cos_cache,
                slot_mapping=attn_metadata.slot_mapping,
                kv_cache=kv_cache,
                block_tables=attn_metadata.block_tables, # tensor
                block_group_size=env_blk_grp_size,
                o_proj_weight=self.fused_params['o_proj_weight'],
                o_proj_weight_scale=self.fused_params['o_proj_weight_scale'],
                o_proj_bias=self.fused_params['o_proj_bias'],
                o_proj_qzeros=self.fused_params['o_proj_qzeros'],
                seq_lens=attn_metadata.seq_lens,
                sm_scale=self.scaling,
                num_attention_heads=self.total_num_heads,
                num_key_value_heads=total_num_kv_heads,
                flash_attention=is_decode, # decode use flash_atten by default
                is_decode=is_decode,
                reduce_result=reduce_result,
                world_size=get_tp_group().world_size,
                rank=get_tp_group().rank_in_group,
                group_id=get_tp_group().group_id,
                dev_info=get_tp_group().rank_device_infos,
                output_opt=attn_outs,
                res_opt=residual)

            # When a residual was supplied, the kernel returns a pair
            # (attn output, updated residual); otherwise a single tensor.
            # Skip the host all-reduce if the kernel already reduced.
            if residual is None:
                attn_out = tensor_model_parallel_all_reduce(attn_outs) if not reduce_result else attn_outs
            else:
                res_out = attn_outs[1]
                attn_out = tensor_model_parallel_all_reduce(attn_outs[0]) if not reduce_result else attn_outs[0]

            return attn_out, res_out
        else:
            # orig code: unfused reference pipeline.
            qkv, _ = self.qkv_proj(hidden_states)
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
            # Add qk-norm: normalize per head, then restore the flat layout.
            q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
                               self.head_dim)
            q_by_head = self.q_norm(q_by_head)
            q = q_by_head.view(q.shape)
            k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
                               self.head_dim)
            k_by_head = self.k_norm(k_by_head)
            k = k_by_head.view(k.shape)
            q, k = self.rotary_emb(positions, q, k)
            attn_output = self.attn(q, k, v)
            output, _ = self.o_proj(attn_output)
            return output
|
||||
|
||||
|
||||
class Qwen3DecoderLayer(nn.Module):
    """One Qwen3 decoder layer: self-attention followed by the MLP block,
    routed through either the VACC fused attention path or the reference
    (unfused) path depending on ``USE_FUSED_QWEN_ATTENTION``."""

    def _ensure_fused_params(self) -> None:
        # Lazily assemble the parameter dict the fused attention kernel
        # consumes; built once per attention module, then cached on it.
        attn = self.self_attn
        if hasattr(attn, "fused_params"):
            return
        attn.fused_params = {}
        attn.fused_params['input_layernorm_weight'] = self.input_layernorm.weight
        attn.fused_params['q_norm_weight'] = attn.q_norm.weight
        attn.fused_params['k_norm_weight'] = attn.k_norm.weight
        set_fused_params(attn.fused_params, attn.qkv_proj.quant_method, attn.qkv_proj, 'qkv_proj')
        set_fused_params(attn.fused_params, attn.o_proj.quant_method, attn.o_proj, 'o_proj')

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply self-attention and MLP; returns (hidden_states, residual)."""
        # Self Attention
        # NOTE: input_layernorm is fused in vacc_fused_attn_qwen3
        if USE_FUSED_QWEN_ATTENTION:
            self._ensure_fused_params()
            hidden_states, residual = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual)
        else:
            if residual is None:
                residual = hidden_states
                hidden_states = self.input_layernorm(hidden_states)
            else:
                hidden_states, residual = self.input_layernorm(
                    hidden_states, residual)
            hidden_states = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
            )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
|
||||
Reference in New Issue
Block a user