init
This commit is contained in:
790
vllm_vacc/vllm/model_executor/models/qwen3_moe.py
Normal file
790
vllm_vacc/vllm/model_executor/models/qwen3_moe.py
Normal file
@@ -0,0 +1,790 @@
|
||||
"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""
|
||||
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union, List
|
||||
import itertools
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from torch_vacc.vacc.custom_ops_cpu import (
|
||||
w8a8_block_fp8_linear as w8a8_block_fp8_linear_cpu,
|
||||
)
|
||||
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
|
||||
from vllm.distributed import (get_pp_group, get_ep_group, get_tp_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tensor_model_parallel_rank,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
# from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
|
||||
from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
|
||||
from vllm.model_executor.models.qwen3_moe import Qwen3MoeSparseMoeBlock, Qwen3MoeMLP
|
||||
from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding, apply_interleaved_rope
|
||||
from vllm.model_executor.models.qwen3_moe import Qwen3MoeSparseMoeBlock
|
||||
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Method
|
||||
|
||||
from ..ops.mrope_op import get_sin_cos_mrope
|
||||
from ..ops.qwen3_fused_moe import vacc_fused_prefill_moe_fp8, vacc_fused_decode_moe_fp8, recompute_moe_layer_blocksize
|
||||
from .vars import *
|
||||
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# uniform the params names from different quantize method
|
||||
# uniform the params names from different quantize method
def set_fused_params(fused_params: Dict[str, Any], quant_method: QuantizeMethodBase, layer: nn.Module, name: str):
    """Collect a linear layer's parameters under uniform keys for the fused kernels.

    Fills ``fused_params`` with ``<name>_weight``, ``<name>_weight_scale``,
    ``<name>_bias`` and ``<name>_qzeros`` regardless of which quantization
    method produced the layer, so the fused attention path can consume them
    without caring about the quantization backend.

    Args:
        fused_params: dict mutated in place with the four uniform entries.
        quant_method: the layer's quantization method instance; dispatch key.
        layer: the linear module whose parameters are harvested.
        name: key prefix, e.g. ``'qkv_proj'`` or ``'o_proj'``.

    Raises:
        ValueError: when ``quant_method`` is not one of the supported methods.
    """
    # BUGFIX: this used to be `if Unquantized: ... / if Fp8: ... elif ... else raise`,
    # so an UnquantizedLinearMethod fell through into the second chain and hit
    # the `else: raise ValueError`. A single if/elif chain fixes that.
    if isinstance(quant_method, UnquantizedLinearMethod):
        # Plain fp16/bf16 weights: no scale / zero-point metadata.
        fused_params[name + '_weight'] = layer.weight
        fused_params[name + '_weight_scale'] = None
        fused_params[name + '_bias'] = None
        fused_params[name + '_qzeros'] = None
    elif isinstance(quant_method, Fp8LinearMethod):
        # Block-wise fp8: scale lives in `weight_scale_inv`.
        fused_params[name + '_weight'] = layer.weight
        fused_params[name + '_weight_scale'] = layer.weight_scale_inv
        fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias
        fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros
    elif isinstance(quant_method, GPTQLinearMethod):
        # GPTQ: packed integer weight in `qweight`, per-group scales in `scales`.
        fused_params[name + '_weight'] = layer.qweight
        fused_params[name + '_weight_scale'] = layer.scales
        fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias
        fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros
    elif isinstance(quant_method, AWQLinearMethod):
        # AWQ uses the same attribute layout as GPTQ for our purposes.
        fused_params[name + '_weight'] = layer.qweight
        fused_params[name + '_weight_scale'] = layer.scales
        fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias
        fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros
    else:
        raise ValueError(f"Unsupported quant_method: {quant_method}")
|
||||
|
||||
|
||||
|
||||
def apply_w8a8_block_fp8_linear_v2(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Apply a block-quantized W8A8 fp8 linear layer.

    Flattens ``input`` to 2-D, runs the device (``vacc``) or CPU fallback
    kernel, optionally adds ``bias``, and restores the original leading
    dimensions with the output feature size ``weight.shape[0]``.

    The per-axis quantization block size is inferred from the ratio between
    the weight shape and the weight-scale shape.
    """
    # No activation scale is supplied; the kernels quantize activations internally.
    scale_of_input = None

    # Collapse all leading dims so the kernels see a plain (tokens, features) matrix.
    flat_input = input.view(-1, input.shape[-1])
    restored_shape = list(input.shape[:-1]) + [weight.shape[0]]

    # Block size per axis = weight extent / scale extent.
    quant_block = [
        weight.shape[-2] // weight_scale.shape[-2],
        weight.shape[-1] // weight_scale.shape[-1],
    ]

    # Pick the device kernel when running on VACC hardware, else the CPU reference.
    on_vacc = input.device.type == "vacc"
    kernel = torch.vacc.w8a8_block_fp8_linear if on_vacc else w8a8_block_fp8_linear_cpu
    result = kernel(flat_input, weight, scale_of_input, weight_scale, quant_block)

    if bias is not None:
        result = result + bias

    # Cast back to the activation dtype and restore the caller's shape.
    return result.to(dtype=input.dtype).view(*restored_shape)
|
||||
|
||||
def vacc_fused_attn_qwen3_naive(
    hidden_states: torch.Tensor,
    residual: Optional[torch.Tensor],
    hidden_states_norm_weight: torch.Tensor,
    qkv_proj_weight: torch.Tensor,
    qkv_proj_weight_scale: torch.Tensor,
    qkv_proj_bias: Optional[torch.Tensor],
    qkv_proj_qzeros: Optional[torch.Tensor],
    q_layernorm_weight: torch.Tensor,
    k_layernorm_weight: torch.Tensor,
    sin_cache: List[torch.Tensor],
    cos_cache: List[torch.Tensor],
    slot_mapping: torch.Tensor,
    kv_cache: torch.Tensor,
    block_tables: torch.Tensor,
    block_group_size: int,
    o_proj_weight: torch.Tensor,
    o_proj_weight_scale: torch.Tensor,
    o_proj_bias: Optional[torch.Tensor],
    o_proj_qzeros: Optional[torch.Tensor],
    seq_lens: List[int],
    sm_scale: float,
    num_attention_heads: int,
    num_key_value_heads: int,
    flash_attention: bool,
    is_decode: bool,
    reduce_result: bool,
    world_size: int,
    rank: int,
    group_id: int,
    dev_info: List[int] | Tuple[int],
    block_size: int = 16
):
    """Pure-Python reference for the fused Qwen3 attention block on VACC.

    Mirrors `torch.vacc.fuse_atten_qwen3`: residual add + input RMSNorm,
    fused fp8 QKV projection, per-head q/k RMSNorm, rotary embedding,
    KV-cache write, scaled-dot-product attention (prefill or decode), and
    the fp8 output projection with optional TP all-reduce.

    Returns ``(o_proj, residual_out)`` when ``residual`` is not None,
    otherwise just ``o_proj``.

    NOTE(review): several parameters (qkv_proj_bias/qzeros, o_proj_bias/qzeros,
    flash_attention, rank, group_id, dev_info) are accepted for signature
    parity with the fused kernel but unused in this reference path.
    """
    if residual is not None:
        # Fold the previous layer's residual in, and remember the sum as the
        # residual handed to the next layer.
        hidden_states = hidden_states + residual
        residual_out = hidden_states

    # Input RMSNorm; the kernel expects a leading batch dim, hence the
    # unsqueeze/squeeze pair. eps is fixed at 1e-6.
    hidden_states = torch.vacc.rms_norm(
        hidden_states.unsqueeze(0), hidden_states_norm_weight, 1e-6).squeeze(0)

    # NOTE: for qwen3 and qwen2.5, head_dim is always 128
    head_dim = 128

    # Fused fp8 QKV projection (bias/qzeros are not applied here).
    qkv = apply_w8a8_block_fp8_linear_v2(
        input=hidden_states,
        weight=qkv_proj_weight,
        weight_scale=qkv_proj_weight_scale)

    # Heads are sharded across tensor-parallel ranks.
    num_q_heads = num_attention_heads // world_size
    num_kv_heads = num_key_value_heads // world_size

    q_size = head_dim * num_q_heads
    kv_size = head_dim * num_kv_heads

    q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)

    # Per-head q/k RMSNorm (Qwen3 qk-norm).
    q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim)
    # q_by_head = self.q_norm.forward_native(q_by_head)
    q_norm = torch.vacc.rms_norm(q_by_head, q_layernorm_weight, 1e-6)

    k_by_head = k.view(*k.shape[:-1], k.shape[-1] // head_dim, head_dim)
    # k_by_head = k_norm.forward_native(k_by_head)
    k_norm = torch.vacc.rms_norm(k_by_head, k_layernorm_weight, 1e-6)

    v = v.view(-1, num_kv_heads, head_dim)

    # Rotary embedding + attention are applied per sequence below.
    start = 0
    attn_outs = []

    if is_decode:
        # convert block_tables to 8K group index: the cache is addressed in
        # groups of `block_group_size` tokens rather than vLLM blocks.
        block_per_group = block_group_size // block_size
        block_tables = (block_tables // block_per_group).to(torch.int32)

    # kv_cache layout: (2, num_blocks, block_size, num_kv_heads, head_dim);
    # index 0 is keys, index 1 is values.
    num_blocks = kv_cache.shape[1]
    key_cache_split = kv_cache[0].view(num_blocks, -1, num_kv_heads, head_dim)
    value_cache_split = kv_cache[1].view(num_blocks, -1, num_kv_heads, head_dim)

    # Per-sequence loop: prefill consumes seq_lens[i] tokens, decode exactly 1.
    for i in range(len(seq_lens)):
        if not is_decode:
            # prefill
            end = start + seq_lens[i]
        else:
            # decode
            end = start + 1

        # Broadcast sin/cos over the head dimension.
        cos = cos_cache[i].unsqueeze(-2)
        sin = sin_cache[i].unsqueeze(-2)

        q, k = torch.vacc.RotaryPosEmbedding(
            q_norm[start : end, ...], k_norm[start : end, ...], cos, sin, 0, "neox")

        # Write this sequence's new k/v tokens into the paged cache.
        torch.vacc.reshape_and_cache_attention(k, key_cache_split, slot_mapping[start : end, ...])
        torch.vacc.reshape_and_cache_attention(v[start : end, ...], value_cache_split, slot_mapping[start : end, ...])

        if not is_decode:
            # prefill: causal attention over the fresh q/k/v only.
            attn_out = torch.vacc.scaled_dot_product_attention(
                query=q,
                key=k,
                value=v[start : end, ...],
                attn_mask = None,
                dropout_p = 0.0,
                is_causal = True, #causal_attn and not self.need_mask,
                is_train = False,
                recompute = False,
                flash_attention = False,
                sm_scale=sm_scale)
        else:
            # decode: gather the full cached history for this sequence via the
            # (group-indexed) block table, then attend non-causally with the
            # single new query token.
            key_cache = key_cache_split.view(-1, block_group_size, num_kv_heads, head_dim)
            value_cache = value_cache_split.view(-1, block_group_size, num_kv_heads, head_dim)

            # NOTE(review): the comprehension index `i` shadows the outer loop
            # variable; `block_tables[i]` in range() is evaluated with the
            # outer `i` before the shadowing starts, so this works, but it is
            # fragile — worth renaming if this path is ever touched.
            k_slices = key_cache[block_tables[i], ...]
            k_cached = torch.cat(
                [k_slices[i].unsqueeze(1) for i in range(len(block_tables[i]))],
                dim=0,
            )
            # Trim padding from the last partially-filled group.
            k_cached = k_cached.view(-1, key_cache.shape[2], key_cache.shape[3])[:seq_lens[i]]

            v_slices = value_cache[block_tables[i], ...]
            v_cached = torch.cat(
                [v_slices[i].unsqueeze(1) for i in range(len(block_tables[i]))],
                dim=0,
            )
            v_cached = v_cached.view(-1, value_cache.shape[2], value_cache.shape[3])[:seq_lens[i]]
            attn_out = torch.vacc.scaled_dot_product_attention(
                query=q,
                key=k_cached,
                value=v_cached,
                attn_mask=None,
                dropout_p=0,
                is_causal=False,
                is_train=False,
                recompute=False,
                flash_attention=False,#flash_attention,
                sm_scale=sm_scale)

        attn_outs.append(attn_out)
        # update start
        start = end
    attn_out = torch.cat(attn_outs, dim=0)

    # output, _ = self.o_proj(attn_output)
    o_proj = apply_w8a8_block_fp8_linear_v2(
        input = attn_out.reshape(hidden_states.shape[0], -1),
        weight = o_proj_weight,
        weight_scale = o_proj_weight_scale,
    )

    # TP all-reduce here only when the caller asked for it (decode path).
    if reduce_result:
        o_proj = tensor_model_parallel_all_reduce(o_proj)

    if residual is not None:
        return o_proj, residual_out
    return o_proj
|
||||
|
||||
def Qwen3MoeSparseMoeBlock__init__(
    self,
    vllm_config: VllmConfig,
    prefix: str = "",
):
    """Replacement ``__init__`` monkey-patched onto ``Qwen3MoeSparseMoeBlock``.

    Reproduces upstream vLLM initialization (TP/EP bookkeeping, EPLB
    accounting, ``FusedMoE`` experts and the router ``gate``) and then adds a
    VACC-specific post-processing step that re-lays-out the ``w2`` weights for
    block-quantized models.
    """
    super(Qwen3MoeSparseMoeBlock, self).__init__()
    config = vllm_config.model_config.hf_text_config
    parallel_config = vllm_config.parallel_config
    quant_config = vllm_config.quant_config

    self.tp_size = get_tensor_model_parallel_world_size()

    # Expert-parallel group/rank/size used for physical expert placement.
    self.ep_group = get_ep_group().device_group
    self.ep_rank = self.ep_group.rank()
    self.ep_size = self.ep_group.size()
    self.n_routed_experts = config.num_experts

    self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

    # Each TP rank must own at least one expert shard.
    if self.tp_size > config.num_experts:
        raise ValueError(
            f"Tensor parallel size {self.tp_size} is greater than "
            f"the number of experts {config.num_experts}.")

    # Load balancing settings (EPLB: expert-parallel load balancing).
    vllm_config = get_current_vllm_config()
    eplb_config = vllm_config.parallel_config.eplb_config
    self.enable_eplb = parallel_config.enable_eplb

    # Physical experts = logical experts + redundant replicas, split evenly
    # across the EP group; this rank owns [physical_expert_start, physical_expert_end).
    self.n_logical_experts = self.n_routed_experts
    self.n_redundant_experts = eplb_config.num_redundant_experts
    self.n_physical_experts = (self.n_logical_experts +
                               self.n_redundant_experts)
    self.n_local_physical_experts = self.n_physical_experts // self.ep_size

    self.physical_expert_start = (self.ep_rank *
                                  self.n_local_physical_experts)
    self.physical_expert_end = (self.physical_expert_start +
                                self.n_local_physical_experts)

    self.experts = FusedMoE(num_experts=self.n_routed_experts,
                            top_k=config.num_experts_per_tok,
                            hidden_size=config.hidden_size,
                            intermediate_size=config.moe_intermediate_size,
                            reduce_results=True,
                            renormalize=config.norm_topk_prob,
                            quant_config=quant_config,
                            prefix=f"{prefix}.experts",
                            enable_eplb=self.enable_eplb,
                            num_redundant_experts=self.n_redundant_experts,
                            is_sequence_parallel=self.is_sequence_parallel)

    # Router: unsharded linear mapping hidden states -> expert logits.
    self.gate = ReplicatedLinear(config.hidden_size,
                                 config.num_experts,
                                 bias=False,
                                 quant_config=quant_config,
                                 prefix=f"{prefix}.gate")

    # Patch: transpose the data arrangement of w2 / w2_scale, only for
    # block-quantized models. transpose->contiguous->transpose keeps the
    # logical shape but changes the physical memory layout, which is what the
    # VACC fused-MoE kernels expect for the down-projection weights.
    if hasattr(self.experts.quant_method, 'quant_config') and hasattr(self.experts.quant_method.quant_config, 'weight_block_size'):
        self.experts.w2_weight.data = self.experts.w2_weight.data.transpose(-1,-2).contiguous().transpose(-1,-2)
        self.experts.w2_weight_scale_inv.data = self.experts.w2_weight_scale_inv.data.transpose(-1,-2).contiguous().transpose(-1,-2)
        if hasattr(self.experts, 'w2_weight_scale_inv_prefill'):
            self.experts.w2_weight_scale_inv_prefill.data = self.experts.w2_weight_scale_inv_prefill.data.transpose(-1,-2).contiguous().transpose(-1,-2)
|
||||
|
||||
def get_cos_sin_cache(rotary_emb: Union["MRotaryEmbedding", "RotaryEmbedding"],
                      attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]],
                      positions: Union[torch.Tensor, list],
                      is_decode: bool):
    """Return per-sequence (cos, sin) rotary caches for the fused kernels.

    For MRoPE, the full cache from ``get_sin_cos_mrope`` is partitioned per
    sequence (equal chunks in decode, ``seq_lens``-sized splits in prefill).
    For plain RoPE, slices are taken directly from the embedding's
    ``cos_cache`` / ``sin_cache`` tables: the last position of each sequence
    in decode, the leading ``seq_len`` rows in prefill.
    """
    seq_lens = attn_metadata.seq_lens
    num_seqs = len(seq_lens)

    if isinstance(rotary_emb, MRotaryEmbedding):
        # Multimodal rotary embedding: compute once, then partition.
        cos_full, sin_full = get_sin_cos_mrope(rotary_emb, positions)
        if num_seqs <= 1:
            return [cos_full], [sin_full]
        if is_decode:
            # One token per sequence -> equal-sized chunks.
            return (torch.chunk(cos_full, num_seqs),
                    torch.chunk(sin_full, num_seqs))
        # Prefill: split by each sequence's token count.
        return (torch.split(cos_full, seq_lens),
                torch.split(sin_full, seq_lens))

    # Standard RoPE: slice precomputed tables per sequence.
    if is_decode:
        last_positions = [length - 1 for length in seq_lens]
        cos_list = [rotary_emb.cos_cache[p:p + 1, ...] for p in last_positions]
        sin_list = [rotary_emb.sin_cache[p:p + 1, ...] for p in last_positions]
    else:
        cos_list = [rotary_emb.cos_cache[:length, ...] for length in seq_lens]
        sin_list = [rotary_emb.sin_cache[:length, ...] for length in seq_lens]
    return cos_list, sin_list
|
||||
|
||||
class Qwen3MoeDecoderLayer(nn.Module):
    """VACC override of the Qwen3-MoE decoder layer forward.

    Routes attention through the fused VACC kernel (when enabled) and the MoE
    block through fused prefill/decode kernels, falling back to the stock
    layernorm + mlp path when a fused call is unavailable or fails.
    """

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        cos_cache: list[torch.Tensor],
        sin_cache: list[torch.Tensor]
    ) -> torch.Tensor:
        # Self Attention
        forward_context: ForwardContext = get_forward_context()
        attn_metadata = forward_context.attn_metadata

        # NOTE: input_layernorm is fused in vacc_fused_attn_qwen3
        if USE_FUSED_QWEN_ATTENTION:
            # Lazily build and cache the uniform fused-parameter dict on the
            # attention module the first time this layer runs.
            if not hasattr(self.self_attn, "fused_params"):
                self.self_attn.fused_params = {}
                self.self_attn.fused_params['input_layernorm_weight'] = self.input_layernorm.weight
                self.self_attn.fused_params['q_norm_weight'] = self.self_attn.q_norm.weight
                self.self_attn.fused_params['k_norm_weight'] = self.self_attn.k_norm.weight
                set_fused_params(self.self_attn.fused_params, self.self_attn.qkv_proj.quant_method, self.self_attn.qkv_proj, 'qkv_proj')
                set_fused_params(self.self_attn.fused_params, self.self_attn.o_proj.quant_method, self.self_attn.o_proj, 'o_proj')

            # Fused path returns both the attention output and the residual.
            hidden_states, residual = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual,
                cos_cache=cos_cache,
                sin_cache=sin_cache)
        else:
            # Stock vLLM path: explicit input layernorm, attention returns
            # only hidden states.
            if residual is None:
                residual = hidden_states
                hidden_states = self.input_layernorm(hidden_states)
            else:
                hidden_states, residual = self.input_layernorm(
                    hidden_states, residual)
            hidden_states = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
                cos_cache=cos_cache,
                sin_cache=sin_cache
            )

        # TODO for noquant or not block_quant: if the MoE experts are neither
        # block-quantized fp8 nor WNA16, take the default (non-fused) MLP
        # path immediately.
        if not hasattr(self.mlp.experts.quant_method, 'quant_config') or \
            not hasattr(self.mlp.experts.quant_method.quant_config, 'weight_block_size'):
            if not isinstance(self.mlp.experts.quant_method, MoeWNA16Method):
                logger.warning('TODO for noquant or other quant')
                hidden_states, residual = self.post_attention_layernorm(
                    hidden_states, residual)
                hidden_states = self.mlp(hidden_states)
                return hidden_states, residual

        # attn_metadata may be a dict keyed by layer name; any entry works
        # since prefill/decode phase is global to the batch.
        if isinstance(attn_metadata, dict):
            attn_metadata_0 = get_forward_context().attn_metadata.items().__iter__().__next__()[1]
            is_prefill = attn_metadata_0.prefill_metadata
        else:
            is_prefill = get_forward_context().attn_metadata.prefill_metadata

        # Dense-MLP layers (non-sparse) expose quant_method via down_proj.
        quant_method = self.mlp.experts.quant_method if isinstance(self.mlp, Qwen3MoeSparseMoeBlock) \
            else self.mlp.down_proj.quant_method

        # Dispatch to the fused MoE kernel matching (phase, quant scheme).
        # Each fused call returns (hidden_states, residual) directly; on
        # failure we log and fall through to the default path below.
        if is_prefill is not None:
            if isinstance(quant_method, MoeWNA16Method):
                try:
                    from vllm_vacc.vllm.model_executor.ops.qwen3_fused_moe import vacc_fused_prefill_moe_gptq_int4
                    return vacc_fused_prefill_moe_gptq_int4(hidden_states,
                                                            residual,
                                                            self.post_attention_layernorm,
                                                            self.mlp.gate,
                                                            self.mlp.experts)
                except Exception as e:
                    print(f'vacc_fused_prefill_moe_gptq_int4 fail: {e}')
            else:
                # fp8 block-quant prefill; make sure block sizes are current.
                recompute_moe_layer_blocksize(self.mlp.experts)
                try:
                    return vacc_fused_prefill_moe_fp8(hidden_states,
                                                      residual,
                                                      self.post_attention_layernorm,
                                                      self.mlp.gate,
                                                      self.mlp.experts)
                except Exception as e:
                    print(f'vacc_fused_prefill_moe_fp8 fail: {e}')
        else:
            if isinstance(quant_method, MoeWNA16Method):
                try:
                    from vllm_vacc.vllm.model_executor.ops.qwen3_fused_moe import vacc_fused_decode_moe_gptq_int4
                    return vacc_fused_decode_moe_gptq_int4(hidden_states,
                                                           residual,
                                                           self.post_attention_layernorm,
                                                           self.mlp.gate,
                                                           self.mlp.experts)
                except Exception as e:
                    print(f'vacc_fused_decode_moe_gptq_int4 fail: {e}')
            else:
                try:
                    return vacc_fused_decode_moe_fp8(hidden_states,
                                                     residual,
                                                     self.post_attention_layernorm,
                                                     self.mlp.gate,
                                                     self.mlp.experts)
                except Exception as e:
                    print(f'vacc_fused_decode_moe_fp8 fail: {e}')

        # Fallback: standard post-attention layernorm + (unfused) MoE MLP.
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
|
||||
|
||||
class Qwen3MoeAttention(nn.Module):
    """VACC override of Qwen3-MoE attention forward.

    When ``USE_FUSED_QWEN_ATTENTION`` is set, the whole block (input norm,
    QKV, qk-norm, RoPE, KV-cache write, attention, output projection) runs
    inside the ``torch.vacc.fuse_atten_qwen3`` kernel; otherwise the stock
    per-op path is used.
    """

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor] = None, # new added params
        cos_cache: list[torch.Tensor] = None,
        sin_cache: list[torch.Tensor] = None,
    ) -> torch.Tensor:

        forward_context: ForwardContext = get_forward_context()
        attn_metadata_all = forward_context.attn_metadata
        kv_cache = self.attn.kv_cache[forward_context.virtual_engine]

        # reshape kvcache to (k/v, blocks, block_size=16, kv_heads, head_dim)
        num_kv_heads = max(1, self.total_num_kv_heads // get_tp_group().world_size)
        kv_cache = kv_cache.view(2, -1, 16, num_kv_heads, self.head_dim)

        # attn_metadata may be a per-layer dict; any entry tells us the phase.
        if isinstance(attn_metadata_all, dict):
            attn_metadata = attn_metadata_all.items().__iter__().__next__()[1]
            is_decode = attn_metadata.prefill_metadata is None
        else:
            is_decode = attn_metadata_all.prefill_metadata is None
            attn_metadata = attn_metadata_all

        # Decode lets the kernel perform the TP all-reduce itself; prefill
        # reduces outside the kernel (see below).
        reduce_result = is_decode

        if USE_FUSED_QWEN_ATTENTION:
            if cos_cache is None or sin_cache is None:
                cos_cache, sin_cache = get_cos_sin_cache(self.rotary_emb, attn_metadata, positions, is_decode)

            # Without a residual the kernel returns a single tensor and the
            # input doubles as the residual handed back to the caller.
            if residual is None:
                res_out = hidden_states
            attn_outs = None
            if not is_decode:
                # Prefill may reuse a preallocated output buffer from the
                # memory recycler to cut peak memory.
                from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
                if memory_recycler is not None:
                    attn_outs = memory_recycler.MLA_OPROJ_OUT_BUFFER

            # When TP ranks outnumber KV heads, heads are replicated so each
            # rank still owns one; the kernel needs the padded total.
            total_num_kv_heads = self.total_num_kv_heads
            if self.total_num_kv_heads < get_tp_group().world_size:
                assert get_tp_group().world_size % self.total_num_kv_heads == 0
                total_num_kv_heads = get_tp_group().world_size
            attn_outs = torch.vacc.fuse_atten_qwen3(
                # attn_outs = vacc_fused_attn_qwen3_naive(  # CPU/debug reference
                hidden_states=hidden_states,
                residual=residual,
                hidden_states_norm_weight=self.fused_params['input_layernorm_weight'],
                qkv_proj_weight=self.fused_params['qkv_proj_weight'],
                qkv_proj_weight_scale=self.fused_params['qkv_proj_weight_scale'],
                qkv_proj_bias=self.fused_params['qkv_proj_bias'],
                qkv_proj_qzeros=self.fused_params['qkv_proj_qzeros'],
                q_layernorm_weight=self.fused_params['q_norm_weight'],
                k_layernorm_weight=self.fused_params['k_norm_weight'],
                sin_cache=sin_cache,
                cos_cache=cos_cache,
                slot_mapping=attn_metadata.slot_mapping,
                kv_cache=kv_cache,
                block_tables=attn_metadata.block_tables,
                block_group_size=env_blk_grp_size,
                o_proj_weight=self.fused_params['o_proj_weight'],
                o_proj_weight_scale=self.fused_params['o_proj_weight_scale'],
                o_proj_bias=self.fused_params['o_proj_bias'],
                o_proj_qzeros=self.fused_params['o_proj_qzeros'],
                seq_lens=attn_metadata.seq_lens,
                sm_scale=self.scaling,
                num_attention_heads=self.total_num_heads,
                num_key_value_heads=total_num_kv_heads,
                flash_attention=is_decode, # decode use flash_atten by default
                is_decode=is_decode,
                reduce_result=reduce_result,
                world_size=get_tp_group().world_size,
                rank=get_tp_group().rank_in_group,
                group_id=get_tp_group().group_id,
                dev_info=get_tp_group().rank_device_infos,
                output_opt=attn_outs,
                res_opt=residual)
            # (debug helper debug_qwen3_moe_attention_prefill call removed —
            # see repository history if a comparison trace is needed.)

            # Kernel returns a single tensor when residual is None, otherwise
            # an (attn_out, residual_out) pair; reduce here if the kernel
            # didn't already.
            if residual is None:
                attn_out = tensor_model_parallel_all_reduce(attn_outs) if not reduce_result else attn_outs
            else:
                res_out = attn_outs[1]
                attn_out = tensor_model_parallel_all_reduce(attn_outs[0]) if not reduce_result else attn_outs[0]

            return attn_out, res_out
        else:
            # orig code: unfused per-op path. NOTE(review): this branch
            # returns a single tensor, not an (out, residual) pair.
            qkv, _ = self.qkv_proj(hidden_states)
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

            # Add qk-norm (per-head RMSNorm on q and k).
            q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
                               self.head_dim)
            q_by_head = self.q_norm.forward_native(q_by_head)

            q = q_by_head.view(q.shape)

            k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
                               self.head_dim)
            k_by_head = self.k_norm.forward_native(k_by_head)

            k = k_by_head.view(k.shape)

            q, k = self.rotary_emb(positions, q, k)

            attn_output = self.attn(q, k, v)

            output, _ = self.o_proj(attn_output)
            return output
|
||||
|
||||
class Qwen3MoeModel(nn.Module):
    """VACC override of the Qwen3-MoE model forward.

    Adds a whole-layer fused decode path (``qwen3_fuse_attention_moe_decode``)
    driven by a lazily-built weight-capture mapper, with the per-layer path as
    the general fallback.
    """

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        deepstack_input_embeds: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:

        forward_context: ForwardContext = get_forward_context()
        attn_metadata_all = forward_context.attn_metadata
        # Lazily capture flattened per-layer weight references for the fused
        # whole-layer decode kernel (built once, on the first forward).
        if not hasattr(self, "weight_capture"):
            from vllm_vacc.vllm.model_executor.models.weight_capture.qwen3_moe_weight_capture import Qwen3Moe_WeightCapture
            self.weight_capture = Qwen3Moe_WeightCapture(self.layers, self.start_layer, self.end_layer)
            self.layer_nums = self.end_layer - self.start_layer

        # Pipeline-parallel input: first rank embeds, later ranks receive
        # (hidden_states, residual) from the previous stage.
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # fused layer decoder only support fp8 quant model now
        use_default_layer = self.weight_capture.support_fused_weights and USE_DECODER_LAYER_FUSE_MODE
        # attn_metadata may be a per-layer dict; any entry gives the phase.
        if isinstance(attn_metadata_all, dict):
            attn_metadata = attn_metadata_all.items().__iter__().__next__()[1]
            is_decode = attn_metadata.prefill_metadata is None
        else:
            is_decode = attn_metadata_all.prefill_metadata is None
            attn_metadata = attn_metadata_all

        if(use_default_layer and is_decode):
            # Whole-layer fused decode: one kernel call per layer covers
            # attention + MoE, fed directly from the captured weight arrays.
            from torch_vacc.vacc.custom_ops import qwen3_fuse_attention_moe_decode

            # Rotary caches are shared across layers; compute once from layer 0.
            layer0 = self.layers[self.start_layer]
            cos_cache, sin_cache = get_cos_sin_cache(layer0.self_attn.rotary_emb, attn_metadata, positions, is_decode=True)

            for i in range(0, self.layer_nums):
                layer = self.layers[i + self.start_layer]
                kv_cache = layer.self_attn.attn.kv_cache[forward_context.virtual_engine]
                num_kv_heads = max(1, layer.self_attn.total_num_kv_heads // get_tp_group().world_size)
                # kv cache viewed as (k/v, blocks, block_size=16, kv_heads, head_dim).
                kv_cache = kv_cache.view(2, -1, 16, num_kv_heads, layer.self_attn.head_dim)
                # Pad KV-head count when TP ranks outnumber KV heads (heads replicated).
                total_num_kv_heads = layer.self_attn.total_num_kv_heads
                if layer.self_attn.total_num_kv_heads < get_tp_group().world_size:
                    assert get_tp_group().world_size % layer.self_attn.total_num_kv_heads == 0
                    total_num_kv_heads = get_tp_group().world_size

                hidden_states, residual = qwen3_fuse_attention_moe_decode(hidden_states, residual,
                    # attention weights from the capture mapper (indexed by layer)
                    hidden_states_norm_weight=self.weight_capture.layer_mapper.attn_args._0_input_layernorm_weight[i],
                    qkv_proj_weight=self.weight_capture.layer_mapper.attn_args._1_qkv_proj_weight[i],
                    qkv_proj_weight_scale_inv=self.weight_capture.layer_mapper.attn_args._2_qkv_proj_weight_scale[i],
                    qkv_proj_bias=self.weight_capture.layer_mapper.attn_args._3_qkv_proj_bias[i],
                    qkv_proj_qzeros=self.weight_capture.layer_mapper.attn_args._4_qkv_proj_qzeros[i],
                    q_layernorm_weight=self.weight_capture.layer_mapper.attn_args._5_q_norm_weight[i],
                    k_layernorm_weight=self.weight_capture.layer_mapper.attn_args._6_k_norm_weight[i],
                    sin_cache=sin_cache,
                    cos_cache=cos_cache,
                    slot_mapping=attn_metadata.slot_mapping,
                    kv_cache=kv_cache,
                    block_tables=attn_metadata.block_tables,
                    block_group_size=env_blk_grp_size,
                    o_proj_weight=self.weight_capture.layer_mapper.attn_args._13_o_proj_weight[i],
                    o_proj_weight_scale_inv=self.weight_capture.layer_mapper.attn_args._14_o_proj_weight_scale[i],
                    o_proj_bias=self.weight_capture.layer_mapper.attn_args._15_o_proj_bias[i],
                    o_proj_qzeros=self.weight_capture.layer_mapper.attn_args._16_o_proj_qzeros[i],
                    seq_lens_num=attn_metadata.seq_lens,
                    sm_scale=layer.self_attn.scaling,
                    num_attention_heads=layer.self_attn.total_num_heads,
                    num_key_value_heads=total_num_kv_heads,
                    flash_attentiton=True,
                    is_decode=True,
                    reduce_result=True,
                    # moe weights
                    rms_weight=self.weight_capture.layer_mapper.moe_args._0_rms_norm_weight[i],
                    moe_weight_13=self.weight_capture.layer_mapper.moe_args._1_w13_weight[i],
                    moe_weight_2=self.weight_capture.layer_mapper.moe_args._2_w2_weight[i],
                    moe_weight_13_dequat=self.weight_capture.layer_mapper.moe_args._3_w13_weight_scale_inv[i],
                    moe_weight_2_dequant=self.weight_capture.layer_mapper.moe_args._4_w2_weight_scale_inv[i],
                    gate_weight=self.weight_capture.layer_mapper.moe_args._5_gate_weight[i],
                    block_size_13=self.weight_capture.layer_mapper.moe_args._6_w13_block_size,
                    block_size_2=self.weight_capture.layer_mapper.moe_args._7_w2_block_size,
                    # distributed topology
                    world_size=self.weight_capture.layer_mapper.dist_args._0_world_size,
                    rank=self.weight_capture.layer_mapper.dist_args._1_rank,
                    group_id=self.weight_capture.layer_mapper.dist_args._2_group_id,
                    dev_info=self.weight_capture.layer_mapper.dist_args._3_dev_info)
        else:
            # Per-layer path (prefill, or fused-layer mode unavailable).
            layer0 = self.layers[self.start_layer]
            cos_cache, sin_cache = get_cos_sin_cache(layer0.self_attn.rotary_emb, attn_metadata, positions, is_decode)

            for i in range(self.start_layer, self.end_layer):
                layer = self.layers[i]
                hidden_states, residual = layer(positions, hidden_states, residual, cos_cache, sin_cache )
                # Deepstack (multimodal) injection: add extra embeddings to the
                # outputs of the first len(deepstack_input_embeds) layers.
                if deepstack_input_embeds is not None and i in range(0, len(deepstack_input_embeds)):
                    if isinstance(deepstack_input_embeds, IntermediateTensors):
                        hidden_states = hidden_states + deepstack_input_embeds[f"deepstack_input_embeds_{i}"]
                    elif isinstance(deepstack_input_embeds, torch.Tensor):
                        hidden_states = hidden_states + deepstack_input_embeds[i]
                    else:
                        raise ValueError(f'unsupported type: {type(deepstack_input_embeds)}')

        # Non-final pipeline stages forward (hidden_states, residual) onward.
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        # Final norm folds in the residual when one exists.
        if residual is not None:
            hidden_states, _ = self.norm(hidden_states, residual)
        else:
            hidden_states = self.norm(hidden_states, residual)
        return hidden_states
|
||||
|
||||
class Qwen3MoeForCausalLM(nn.Module):
    """VACC override of the Qwen3-MoE causal-LM entry points.

    ``forward`` arms the prefill memory recycler before delegating to the
    model; ``load_weights`` initializes the huge-memory allocator and then
    loads weights via vLLM's ``AutoWeightsLoader``.
    """

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        deepstack_input_embeds = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        attn_metadata = get_forward_context().attn_metadata
        # Per-layer dict: any entry carries the batch-global phase info.
        if isinstance(attn_metadata, dict):
            attn_metadata = attn_metadata.items().__iter__().__next__()[1]
        # NOTE(review): attn_metadata is assumed non-None here — confirm the
        # caller never invokes forward() outside a batch context.
        if attn_metadata.prefill_metadata is not None:
            # Prefill: size and arm the memory recycler for this request.
            from .memory.memory_recycling import alloc_memory_recycler
            from vllm_vacc.vllm.config_manager import vllm_vacc_config_manager
            # Token count lives either on the top-level metadata or inside
            # prefill_metadata depending on the attention backend.
            if hasattr(attn_metadata, 'num_prefill_tokens'):
                tokens = attn_metadata.num_prefill_tokens
            else:
                tokens = attn_metadata.prefill_metadata.num_prefill_tokens

            # Model tag for the recycler; the config manager may override it.
            vllm_model_mode = "qwen3_moe"
            config_infos = vllm_vacc_config_manager().get_model_infos()
            if config_infos != "default":
                vllm_model_mode = config_infos

            # Log once per TP group (rank 0 only) to avoid duplicate lines.
            if get_tp_group().rank_in_group == 0:
                memory_infos = f'[MemoryRecycler] enable: {vllm_model_mode}'
                logger.info(memory_infos)

            # Allocation failure is non-fatal: the request just runs without
            # buffer recycling.
            if not alloc_memory_recycler(tokens, vllm_model=vllm_model_mode, world_size=get_tp_group().world_size, dtype=self.lm_head.weight.dtype):
                logger.warning("deepseek memory recycler allock fail. current request may inefficient %s", tokens)

        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds, deepstack_input_embeds)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Initialize the huge-memory allocator, then load model weights.

        Returns the set of parameter names that were loaded (from
        ``AutoWeightsLoader``).
        """
        from .memory.memory_recycling import init_huge_memory_allocator
        from .vars import LLM_MAX_PREFILL_SEQ_LEN
        from vllm_vacc.vllm.config_manager import vllm_vacc_config_manager

        # default is deepseek, config can set to ['deepseek_mtp',]
        model_name = "qwen3_moe"
        config_infos = vllm_vacc_config_manager().get_model_infos()
        if config_infos != "default":
            model_name = config_infos

        # Allocator sizing uses the max prefill length and hidden size;
        # failure only disables prefill memory recycling, not loading.
        if not init_huge_memory_allocator(LLM_MAX_PREFILL_SEQ_LEN, self.config.hidden_size, vllm_model=model_name):
            logger.warning("init huge memory allocator fail. prefill memory recycling will disable")

        from vllm.model_executor.models.utils import AutoWeightsLoader
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)
|
||||
Reference in New Issue
Block a user