"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union, List
import itertools
import os
import torch
from torch import nn
from transformers import PretrainedConfig
from torch_vacc.vacc.custom_ops_cpu import (
w8a8_block_fp8_linear as w8a8_block_fp8_linear_cpu,
)
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (get_pp_group, get_ep_group, get_tp_group,
get_tensor_model_parallel_world_size,
get_tensor_model_parallel_rank,
tensor_model_parallel_all_reduce)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear,
                                               UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
from vllm.model_executor.models.qwen3_moe import (Qwen3MoeSparseMoeBlock,
                                                  Qwen3MoeMLP)
from vllm.model_executor.layers.rotary_embedding.mrope import (
    MRotaryEmbedding, apply_interleaved_rope)
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Method
from ..ops.mrope_op import get_sin_cos_mrope
from ..ops.qwen3_fused_moe import vacc_fused_prefill_moe_fp8, vacc_fused_decode_moe_fp8, recompute_moe_layer_blocksize
from .vars import *
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
logger = init_logger(__name__)
# Normalize the fused-kernel parameter names across the different quantization methods.
def set_fused_params(fused_params: Dict[str, Any], quant_method: QuantizeMethodBase, layer: nn.Module, name: str):
if isinstance(quant_method, UnquantizedLinearMethod):
fused_params[name + '_weight'] = layer.weight
fused_params[name + '_weight_scale'] = None
fused_params[name + '_bias'] = None
fused_params[name + '_qzeros'] = None
    elif isinstance(quant_method, Fp8LinearMethod):
        fused_params[name + '_weight'] = layer.weight
        fused_params[name + '_weight_scale'] = layer.weight_scale_inv
        fused_params[name + '_bias'] = getattr(layer, 'bias', None)
        fused_params[name + '_qzeros'] = getattr(layer, 'qzeros', None)
    elif isinstance(quant_method, GPTQLinearMethod):
        fused_params[name + '_weight'] = layer.qweight
        fused_params[name + '_weight_scale'] = layer.scales
        fused_params[name + '_bias'] = getattr(layer, 'bias', None)
        fused_params[name + '_qzeros'] = getattr(layer, 'qzeros', None)
    elif isinstance(quant_method, AWQLinearMethod):
        fused_params[name + '_weight'] = layer.qweight
        fused_params[name + '_weight_scale'] = layer.scales
        fused_params[name + '_bias'] = getattr(layer, 'bias', None)
        fused_params[name + '_qzeros'] = getattr(layer, 'qzeros', None)
    else:
        raise ValueError(f"Unsupported quant_method: {quant_method}")
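# Example of the resulting dict for an Fp8LinearMethod layer with name='qkv_proj'
# (a sketch of the naming convention, not an exhaustive spec):
#   fused_params['qkv_proj_weight']       -> layer.weight
#   fused_params['qkv_proj_weight_scale'] -> layer.weight_scale_inv
#   fused_params['qkv_proj_bias']         -> layer.bias if present, else None
#   fused_params['qkv_proj_qzeros']       -> layer.qzeros if present, else None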
def apply_w8a8_block_fp8_linear_v2(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
input_scale = None
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
block_size = [
weight.shape[-2] // weight_scale.shape[-2],
weight.shape[-1] // weight_scale.shape[-1],
]
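    # The quantization block size is recovered from the weight/scale shapes; for
    # instance, a weight of shape (4608, 4096) with a weight_scale of shape
    # (36, 32) implies block_size = [4608 // 36, 4096 // 32] = [128, 128].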
if input.device.type == "vacc":
output = torch.vacc.w8a8_block_fp8_linear(
input_2d, weight, input_scale, weight_scale, block_size
)
else:
output = w8a8_block_fp8_linear_cpu(
input_2d, weight, input_scale, weight_scale, block_size
)
if bias is not None:
output = output + bias
return output.to(dtype=input.dtype).view(*output_shape)
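# Minimal usage sketch (shapes illustrative only, assuming 128x128 quant blocks):
#   x = torch.randn(8, 4096, dtype=torch.bfloat16)            # activations
#   w = torch.empty(4608, 4096, dtype=torch.float8_e4m3fn)    # fp8 weight blocks
#   s = torch.empty(36, 32, dtype=torch.float32)              # per-block scales
#   y = apply_w8a8_block_fp8_linear_v2(x, w, s)               # y.shape == (8, 4608)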
def vacc_fused_attn_qwen3_naive(
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
hidden_states_norm_weight: torch.Tensor,
qkv_proj_weight: torch.Tensor,
qkv_proj_weight_scale: torch.Tensor,
qkv_proj_bias: Optional[torch.Tensor],
qkv_proj_qzeros: Optional[torch.Tensor],
q_layernorm_weight: torch.Tensor,
k_layernorm_weight: torch.Tensor,
sin_cache: List[torch.Tensor],
cos_cache: List[torch.Tensor],
slot_mapping: torch.Tensor,
kv_cache: torch.Tensor,
block_tables: torch.Tensor,
block_group_size: int,
o_proj_weight: torch.Tensor,
o_proj_weight_scale: torch.Tensor,
o_proj_bias: Optional[torch.Tensor],
o_proj_qzeros: Optional[torch.Tensor],
seq_lens: List[int],
sm_scale: float,
num_attention_heads: int,
num_key_value_heads: int,
flash_attention: bool,
is_decode: bool,
reduce_result: bool,
world_size: int,
rank: int,
group_id: int,
    dev_info: Union[List[int], Tuple[int, ...]],
block_size: int = 16
):
if residual is not None:
hidden_states = hidden_states + residual
residual_out = hidden_states
hidden_states = torch.vacc.rms_norm(
hidden_states.unsqueeze(0), hidden_states_norm_weight, 1e-6).squeeze(0)
# NOTE: for qwen3 and qwen2.5, head_dim is always 128
head_dim = 128
# qkv gen
qkv = apply_w8a8_block_fp8_linear_v2(
input=hidden_states,
weight=qkv_proj_weight,
weight_scale=qkv_proj_weight_scale)
num_q_heads = num_attention_heads // world_size
num_kv_heads = num_key_value_heads // world_size
q_size = head_dim * num_q_heads
kv_size = head_dim * num_kv_heads
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
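    # Per-rank shapes after the split: q is (num_tokens, num_q_heads * head_dim)
    # and k, v are (num_tokens, num_kv_heads * head_dim).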
    # Apply per-head QK-norm (RMSNorm over each 128-dim head).
    q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim)
    q_norm = torch.vacc.rms_norm(q_by_head, q_layernorm_weight, 1e-6)
    k_by_head = k.view(*k.shape[:-1], k.shape[-1] // head_dim, head_dim)
    k_norm = torch.vacc.rms_norm(k_by_head, k_layernorm_weight, 1e-6)
    v = v.view(-1, num_kv_heads, head_dim)
# q, k = self.rotary_emb(positions, q, k)
start = 0
attn_outs = []
if is_decode:
# convert block_tables to 8K group index
block_per_group = block_group_size // block_size
block_tables = (block_tables // block_per_group).to(torch.int32)
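        # Worked example: with block_size=16 and a block_group_size of 8192 (an
        # 8K group), block_per_group is 512, so logical block id 1024 maps to
        # physical cache group 1024 // 512 = 2.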
# logger.warning(f"decode block table: {block_tables}")
num_blocks = kv_cache.shape[1]
key_cache_split = kv_cache[0].view(num_blocks, -1, num_kv_heads, head_dim)
value_cache_split = kv_cache[1].view(num_blocks, -1, num_kv_heads, head_dim)
# bs loop
for i in range(len(seq_lens)):
if not is_decode:
# prefill
end = start + seq_lens[i]
else:
# decode
end = start + 1
cos = cos_cache[i].unsqueeze(-2)
sin = sin_cache[i].unsqueeze(-2)
q, k = torch.vacc.RotaryPosEmbedding(
q_norm[start : end, ...], k_norm[start : end, ...], cos, sin, 0, "neox")
# cache concat
torch.vacc.reshape_and_cache_attention(k, key_cache_split, slot_mapping[start : end, ...])
torch.vacc.reshape_and_cache_attention(v[start : end, ...], value_cache_split, slot_mapping[start : end, ...])
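        # slot_mapping[t] is the flat cache slot for token t, i.e.
        # block_number * block_size + block_offset in the paged KV cache.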
# attn_output = self.attn(q, k, v)
        if not is_decode:
            # Prefill: full causal attention over the freshly computed q/k/v.
            attn_out = torch.vacc.scaled_dot_product_attention(
                query=q,
                key=k,
                value=v[start:end, ...],
                attn_mask=None,
                dropout_p=0.0,
                is_causal=True,
                is_train=False,
                recompute=False,
                flash_attention=False,
                sm_scale=sm_scale)
        else:
            # Decode: gather this sequence's cached K/V via its block table.
            key_cache = key_cache_split.view(-1, block_group_size, num_kv_heads, head_dim)
            value_cache = value_cache_split.view(-1, block_group_size, num_kv_heads, head_dim)
            k_slices = key_cache[block_tables[i], ...]
            k_cached = torch.cat(
                [k_slices[j].unsqueeze(1) for j in range(len(block_tables[i]))],
                dim=0,
            )
            k_cached = k_cached.view(-1, key_cache.shape[2], key_cache.shape[3])[:seq_lens[i]]
            v_slices = value_cache[block_tables[i], ...]
            v_cached = torch.cat(
                [v_slices[j].unsqueeze(1) for j in range(len(block_tables[i]))],
                dim=0,
            )
            v_cached = v_cached.view(-1, value_cache.shape[2], value_cache.shape[3])[:seq_lens[i]]
            attn_out = torch.vacc.scaled_dot_product_attention(
                query=q,
                key=k_cached,
                value=v_cached,
                attn_mask=None,
                dropout_p=0.0,
                is_causal=False,
                is_train=False,
                recompute=False,
                flash_attention=False,  # TODO: wire up the flash_attention argument
                sm_scale=sm_scale)
attn_outs.append(attn_out)
# update start
start = end
attn_out = torch.cat(attn_outs, dim=0)
    # output, _ = self.o_proj(attn_output)
    o_proj = apply_w8a8_block_fp8_linear_v2(
        input=attn_out.reshape(hidden_states.shape[0], -1),
        weight=o_proj_weight,
        weight_scale=o_proj_weight_scale,
    )
if reduce_result:
o_proj = tensor_model_parallel_all_reduce(o_proj)
if residual is not None:
return o_proj, residual_out
return o_proj
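# NOTE: this naive implementation mirrors the fused torch.vacc.fuse_atten_qwen3
# kernel step by step (norm -> QKV -> QK-norm -> RoPE -> cache write -> SDPA ->
# o_proj) and is kept as a readable reference / debugging fallback.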
def Qwen3MoeSparseMoeBlock__init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
):
super(Qwen3MoeSparseMoeBlock, self).__init__()
config = vllm_config.model_config.hf_text_config
parallel_config = vllm_config.parallel_config
quant_config = vllm_config.quant_config
self.tp_size = get_tensor_model_parallel_world_size()
self.ep_group = get_ep_group().device_group
self.ep_rank = self.ep_group.rank()
self.ep_size = self.ep_group.size()
self.n_routed_experts = config.num_experts
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
if self.tp_size > config.num_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.num_experts}.")
# Load balancing settings.
vllm_config = get_current_vllm_config()
eplb_config = vllm_config.parallel_config.eplb_config
self.enable_eplb = parallel_config.enable_eplb
self.n_logical_experts = self.n_routed_experts
self.n_redundant_experts = eplb_config.num_redundant_experts
self.n_physical_experts = (self.n_logical_experts +
self.n_redundant_experts)
self.n_local_physical_experts = self.n_physical_experts // self.ep_size
self.physical_expert_start = (self.ep_rank *
self.n_local_physical_experts)
self.physical_expert_end = (self.physical_expert_start +
self.n_local_physical_experts)
self.experts = FusedMoE(num_experts=self.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=True,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts",
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel)
self.gate = ReplicatedLinear(config.hidden_size,
config.num_experts,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate")
    # Patch: transpose the data layout of w2 / w2_scale (block-quantized weights only).
if hasattr(self.experts.quant_method, 'quant_config') and hasattr(self.experts.quant_method.quant_config, 'weight_block_size'):
self.experts.w2_weight.data = self.experts.w2_weight.data.transpose(-1,-2).contiguous().transpose(-1,-2)
self.experts.w2_weight_scale_inv.data = self.experts.w2_weight_scale_inv.data.transpose(-1,-2).contiguous().transpose(-1,-2)
if hasattr(self.experts, 'w2_weight_scale_inv_prefill'):
self.experts.w2_weight_scale_inv_prefill.data = self.experts.w2_weight_scale_inv_prefill.data.transpose(-1,-2).contiguous().transpose(-1,-2)
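    # NOTE: transpose().contiguous().transpose() above keeps the logical shape of
    # w2 / w2_scale but rewrites the storage so the last two dims are laid out
    # column-major, which is the layout the vacc block-quantized MoE kernels consume.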
def get_cos_sin_cache(rotary_emb: Union["MRotaryEmbedding", "RotaryEmbedding"],
attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]],
positions: Union[torch.Tensor, list],
is_decode: bool):
if isinstance(rotary_emb, MRotaryEmbedding):
# get mrope sin/cos
cos_cache, sin_cache = get_sin_cos_mrope(rotary_emb, positions)
if len(attn_metadata.seq_lens) > 1:
if is_decode:
cos_cache = torch.chunk(cos_cache, len(attn_metadata.seq_lens))
sin_cache = torch.chunk(sin_cache, len(attn_metadata.seq_lens))
else:
cos_cache = torch.split(cos_cache, attn_metadata.seq_lens)
sin_cache = torch.split(sin_cache, attn_metadata.seq_lens)
else:
cos_cache = [cos_cache]
sin_cache = [sin_cache]
else:
if is_decode:
positions = [i - 1 for i in attn_metadata.seq_lens]
cos_cache = [rotary_emb.cos_cache[i:i+1, ...] for i in positions]
sin_cache = [rotary_emb.sin_cache[i:i+1, ...] for i in positions]
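            # Each decode step appends exactly one token per sequence, so its
            # rotary position is the last index, seq_len - 1.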
else:
cos_cache = [rotary_emb.cos_cache[:i, ...] for i in attn_metadata.seq_lens]
sin_cache = [rotary_emb.sin_cache[:i, ...] for i in attn_metadata.seq_lens]
return cos_cache, sin_cache
class Qwen3MoeDecoderLayer(nn.Module):
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
cos_cache: list[torch.Tensor],
sin_cache: list[torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
# NOTE: input_layernorm is fused in vacc_fused_attn_qwen3
if USE_FUSED_QWEN_ATTENTION:
if not hasattr(self.self_attn, "fused_params"):
self.self_attn.fused_params = {}
self.self_attn.fused_params['input_layernorm_weight'] = self.input_layernorm.weight
self.self_attn.fused_params['q_norm_weight'] = self.self_attn.q_norm.weight
self.self_attn.fused_params['k_norm_weight'] = self.self_attn.k_norm.weight
set_fused_params(self.self_attn.fused_params, self.self_attn.qkv_proj.quant_method, self.self_attn.qkv_proj, 'qkv_proj')
set_fused_params(self.self_attn.fused_params, self.self_attn.o_proj.quant_method, self.self_attn.o_proj, 'o_proj')
hidden_states, residual = self.self_attn(
positions=positions,
hidden_states=hidden_states,
residual=residual,
cos_cache=cos_cache,
sin_cache=sin_cache)
else:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
cos_cache=cos_cache,
sin_cache=sin_cache
)
        # Fallback for unquantized / non-block-quantized models: WNA16 has its own
        # fused kernels and falls through to the dispatch below; everything else
        # takes the unfused norm + MoE path.
        if not hasattr(self.mlp.experts.quant_method, 'quant_config') or \
                not hasattr(self.mlp.experts.quant_method.quant_config, 'weight_block_size'):
            if not isinstance(self.mlp.experts.quant_method, MoeWNA16Method):
                logger.warning('TODO: no fused path for this quant scheme; using the unfused MoE path')
                hidden_states, residual = self.post_attention_layernorm(
                    hidden_states, residual)
                hidden_states = self.mlp(hidden_states)
                return hidden_states, residual
        if isinstance(attn_metadata, dict):
            attn_metadata_0 = next(iter(attn_metadata.values()))
            is_prefill = attn_metadata_0.prefill_metadata
        else:
            is_prefill = attn_metadata.prefill_metadata
quant_method = self.mlp.experts.quant_method if isinstance(self.mlp, Qwen3MoeSparseMoeBlock) \
else self.mlp.down_proj.quant_method
if is_prefill is not None:
if isinstance(quant_method, MoeWNA16Method):
try:
from vllm_vacc.vllm.model_executor.ops.qwen3_fused_moe import vacc_fused_prefill_moe_gptq_int4
return vacc_fused_prefill_moe_gptq_int4(hidden_states,
residual,
self.post_attention_layernorm,
self.mlp.gate,
self.mlp.experts)
            except Exception as e:
                logger.warning('vacc_fused_prefill_moe_gptq_int4 failed, falling back: %s', e)
else:
recompute_moe_layer_blocksize(self.mlp.experts)
try:
return vacc_fused_prefill_moe_fp8(hidden_states,
residual,
self.post_attention_layernorm,
self.mlp.gate,
self.mlp.experts)
            except Exception as e:
                logger.warning('vacc_fused_prefill_moe_fp8 failed, falling back: %s', e)
else:
if isinstance(quant_method, MoeWNA16Method):
try:
from vllm_vacc.vllm.model_executor.ops.qwen3_fused_moe import vacc_fused_decode_moe_gptq_int4
return vacc_fused_decode_moe_gptq_int4(hidden_states,
residual,
self.post_attention_layernorm,
self.mlp.gate,
self.mlp.experts)
            except Exception as e:
                logger.warning('vacc_fused_decode_moe_gptq_int4 failed, falling back: %s', e)
else:
try:
return vacc_fused_decode_moe_fp8(hidden_states,
residual,
self.post_attention_layernorm,
self.mlp.gate,
self.mlp.experts)
            except Exception as e:
                logger.warning('vacc_fused_decode_moe_fp8 failed, falling back: %s', e)
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
class Qwen3MoeAttention(nn.Module):
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor] = None,  # added relative to upstream
        cos_cache: Optional[list[torch.Tensor]] = None,
        sin_cache: Optional[list[torch.Tensor]] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
forward_context: ForwardContext = get_forward_context()
attn_metadata_all = forward_context.attn_metadata
kv_cache = self.attn.kv_cache[forward_context.virtual_engine]
# reshape kvcache
num_kv_heads = max(1, self.total_num_kv_heads // get_tp_group().world_size)
kv_cache = kv_cache.view(2, -1, 16, num_kv_heads, self.head_dim)
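        # Cache layout after the view: (2, num_blocks, block_size=16,
        # num_kv_heads, head_dim), where index 0 along dim 0 holds keys and
        # index 1 holds values.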
        if isinstance(attn_metadata_all, dict):
            attn_metadata = next(iter(attn_metadata_all.values()))
        else:
            attn_metadata = attn_metadata_all
        is_decode = attn_metadata.prefill_metadata is None
reduce_result = is_decode
# total_bytes = hidden_states.numel() * hidden_states.element_size() * get_tp_group().world_size
# # only support 4M now
# if total_bytes < 4194304:
# reduce_result = True
if USE_FUSED_QWEN_ATTENTION:
if cos_cache is None or sin_cache is None:
cos_cache, sin_cache = get_cos_sin_cache(self.rotary_emb, attn_metadata, positions, is_decode)
if residual is None:
res_out = hidden_states
#from torch_vacc.vacc import fuse_atten_qwen3
attn_outs = None
if not is_decode:
from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
if memory_recycler is not None:
attn_outs = memory_recycler.MLA_OPROJ_OUT_BUFFER
total_num_kv_heads = self.total_num_kv_heads
if self.total_num_kv_heads < get_tp_group().world_size:
assert get_tp_group().world_size % self.total_num_kv_heads == 0
total_num_kv_heads = get_tp_group().world_size
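            # With fewer KV heads than TP ranks, the KV heads are replicated so
            # every rank holds one; the padded head count is what the fused
            # kernel expects.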
attn_outs = torch.vacc.fuse_atten_qwen3(
# attn_outs = vacc_fused_attn_qwen3_naive(
hidden_states=hidden_states,
residual=residual,
hidden_states_norm_weight=self.fused_params['input_layernorm_weight'],
qkv_proj_weight=self.fused_params['qkv_proj_weight'],
qkv_proj_weight_scale=self.fused_params['qkv_proj_weight_scale'],
qkv_proj_bias=self.fused_params['qkv_proj_bias'],
qkv_proj_qzeros=self.fused_params['qkv_proj_qzeros'],
q_layernorm_weight=self.fused_params['q_norm_weight'],
k_layernorm_weight=self.fused_params['k_norm_weight'],
sin_cache=sin_cache,
cos_cache=cos_cache,
slot_mapping=attn_metadata.slot_mapping,
kv_cache=kv_cache,
block_tables=attn_metadata.block_tables,
block_group_size=env_blk_grp_size,
o_proj_weight=self.fused_params['o_proj_weight'],
o_proj_weight_scale=self.fused_params['o_proj_weight_scale'],
o_proj_bias=self.fused_params['o_proj_bias'],
o_proj_qzeros=self.fused_params['o_proj_qzeros'],
seq_lens=attn_metadata.seq_lens,
sm_scale=self.scaling,
num_attention_heads=self.total_num_heads,
num_key_value_heads=total_num_kv_heads,
                flash_attention=is_decode,  # decode uses flash attention by default
is_decode=is_decode,
reduce_result=reduce_result,
world_size=get_tp_group().world_size,
rank=get_tp_group().rank_in_group,
group_id=get_tp_group().group_id,
dev_info=get_tp_group().rank_device_infos,
output_opt=attn_outs,
res_opt=residual)
            # debug_qwen3_moe_attention_prefill(...) can be re-enabled here to
            # cross-check the fused kernel output against the naive path.
if residual is None:
attn_out = tensor_model_parallel_all_reduce(attn_outs) if not reduce_result else attn_outs
else:
res_out = attn_outs[1]
attn_out = tensor_model_parallel_all_reduce(attn_outs[0]) if not reduce_result else attn_outs[0]
return attn_out, res_out
else:
# orig code
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
# Add qk-norm
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
self.head_dim)
q_by_head = self.q_norm.forward_native(q_by_head)
q = q_by_head.view(q.shape)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
self.head_dim)
k_by_head = self.k_norm.forward_native(k_by_head)
k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
class Qwen3MoeModel(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
deepstack_input_embeds: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
forward_context: ForwardContext = get_forward_context()
attn_metadata_all = forward_context.attn_metadata
if not hasattr(self, "weight_capture"):
from vllm_vacc.vllm.model_executor.models.weight_capture.qwen3_moe_weight_capture import Qwen3Moe_WeightCapture
self.weight_capture = Qwen3Moe_WeightCapture(self.layers, self.start_layer, self.end_layer)
self.layer_nums = self.end_layer - self.start_layer
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
        # The fused decoder-layer path currently only supports FP8 block-quantized
        # models, and is gated behind USE_DECODER_LAYER_FUSE_MODE.
        use_fused_layer = self.weight_capture.support_fused_weights and USE_DECODER_LAYER_FUSE_MODE
        if isinstance(attn_metadata_all, dict):
            attn_metadata = next(iter(attn_metadata_all.values()))
        else:
            attn_metadata = attn_metadata_all
        is_decode = attn_metadata.prefill_metadata is None
        if use_fused_layer and is_decode:
from torch_vacc.vacc.custom_ops import qwen3_fuse_attention_moe_decode
layer0 = self.layers[self.start_layer]
cos_cache, sin_cache = get_cos_sin_cache(layer0.self_attn.rotary_emb, attn_metadata, positions, is_decode=True)
for i in range(0, self.layer_nums):
layer = self.layers[i + self.start_layer]
kv_cache = layer.self_attn.attn.kv_cache[forward_context.virtual_engine]
num_kv_heads = max(1, layer.self_attn.total_num_kv_heads // get_tp_group().world_size)
kv_cache = kv_cache.view(2, -1, 16, num_kv_heads, layer.self_attn.head_dim)
total_num_kv_heads = layer.self_attn.total_num_kv_heads
if layer.self_attn.total_num_kv_heads < get_tp_group().world_size:
assert get_tp_group().world_size % layer.self_attn.total_num_kv_heads == 0
total_num_kv_heads = get_tp_group().world_size
hidden_states, residual = qwen3_fuse_attention_moe_decode(hidden_states, residual,
hidden_states_norm_weight=self.weight_capture.layer_mapper.attn_args._0_input_layernorm_weight[i],
qkv_proj_weight=self.weight_capture.layer_mapper.attn_args._1_qkv_proj_weight[i],
qkv_proj_weight_scale_inv=self.weight_capture.layer_mapper.attn_args._2_qkv_proj_weight_scale[i],
qkv_proj_bias=self.weight_capture.layer_mapper.attn_args._3_qkv_proj_bias[i],
qkv_proj_qzeros=self.weight_capture.layer_mapper.attn_args._4_qkv_proj_qzeros[i],
q_layernorm_weight=self.weight_capture.layer_mapper.attn_args._5_q_norm_weight[i],
k_layernorm_weight=self.weight_capture.layer_mapper.attn_args._6_k_norm_weight[i],
sin_cache=sin_cache,
cos_cache=cos_cache,
slot_mapping=attn_metadata.slot_mapping,
kv_cache=kv_cache,
block_tables=attn_metadata.block_tables,
block_group_size=env_blk_grp_size,
o_proj_weight=self.weight_capture.layer_mapper.attn_args._13_o_proj_weight[i],
o_proj_weight_scale_inv=self.weight_capture.layer_mapper.attn_args._14_o_proj_weight_scale[i],
o_proj_bias=self.weight_capture.layer_mapper.attn_args._15_o_proj_bias[i],
o_proj_qzeros=self.weight_capture.layer_mapper.attn_args._16_o_proj_qzeros[i],
seq_lens_num=attn_metadata.seq_lens,
sm_scale=layer.self_attn.scaling,
num_attention_heads=layer.self_attn.total_num_heads,
num_key_value_heads=total_num_kv_heads,
flash_attentiton=True,
is_decode=True,
reduce_result=True,
# moe
rms_weight=self.weight_capture.layer_mapper.moe_args._0_rms_norm_weight[i],
moe_weight_13=self.weight_capture.layer_mapper.moe_args._1_w13_weight[i],
moe_weight_2=self.weight_capture.layer_mapper.moe_args._2_w2_weight[i],
moe_weight_13_dequat=self.weight_capture.layer_mapper.moe_args._3_w13_weight_scale_inv[i],
moe_weight_2_dequant=self.weight_capture.layer_mapper.moe_args._4_w2_weight_scale_inv[i],
gate_weight=self.weight_capture.layer_mapper.moe_args._5_gate_weight[i],
block_size_13=self.weight_capture.layer_mapper.moe_args._6_w13_block_size,
block_size_2=self.weight_capture.layer_mapper.moe_args._7_w2_block_size,
# dist
world_size=self.weight_capture.layer_mapper.dist_args._0_world_size,
rank=self.weight_capture.layer_mapper.dist_args._1_rank,
group_id=self.weight_capture.layer_mapper.dist_args._2_group_id,
dev_info=self.weight_capture.layer_mapper.dist_args._3_dev_info)
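            # Each iteration replaces a full decoder layer (attention + router +
            # experts) with a single fused custom-op call, so decode pays roughly
            # one kernel launch per layer instead of one per sub-module.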
else:
layer0 = self.layers[self.start_layer]
cos_cache, sin_cache = get_cos_sin_cache(layer0.self_attn.rotary_emb, attn_metadata, positions, is_decode)
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states, residual, cos_cache, sin_cache )
                if deepstack_input_embeds is not None and i < len(deepstack_input_embeds):
if isinstance(deepstack_input_embeds, IntermediateTensors):
hidden_states = hidden_states + deepstack_input_embeds[f"deepstack_input_embeds_{i}"]
elif isinstance(deepstack_input_embeds, torch.Tensor):
hidden_states = hidden_states + deepstack_input_embeds[i]
else:
raise ValueError(f'unsupported type: {type(deepstack_input_embeds)}')
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
if residual is not None:
hidden_states, _ = self.norm(hidden_states, residual)
else:
hidden_states = self.norm(hidden_states, residual)
return hidden_states
class Qwen3MoeForCausalLM(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
deepstack_input_embeds = None,
) -> Union[torch.Tensor, IntermediateTensors]:
        attn_metadata = get_forward_context().attn_metadata
        if isinstance(attn_metadata, dict):
            attn_metadata = next(iter(attn_metadata.values()))
if attn_metadata.prefill_metadata is not None:
from .memory.memory_recycling import alloc_memory_recycler
from vllm_vacc.vllm.config_manager import vllm_vacc_config_manager
if hasattr(attn_metadata, 'num_prefill_tokens'):
tokens = attn_metadata.num_prefill_tokens
else:
tokens = attn_metadata.prefill_metadata.num_prefill_tokens
vllm_model_mode = "qwen3_moe"
config_infos = vllm_vacc_config_manager().get_model_infos()
if config_infos != "default":
vllm_model_mode = config_infos
if get_tp_group().rank_in_group == 0:
memory_infos = f'[MemoryRecycler] enable: {vllm_model_mode}'
logger.info(memory_infos)
if not alloc_memory_recycler(tokens, vllm_model=vllm_model_mode, world_size=get_tp_group().world_size, dtype=self.lm_head.weight.dtype):
logger.warning("deepseek memory recycler allock fail. current request may inefficient %s", tokens)
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds, deepstack_input_embeds)
return hidden_states
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
from .memory.memory_recycling import init_huge_memory_allocator
from .vars import LLM_MAX_PREFILL_SEQ_LEN
from vllm_vacc.vllm.config_manager import vllm_vacc_config_manager
        # Defaults to "qwen3_moe"; the config manager may override the model name.
model_name = "qwen3_moe"
config_infos = vllm_vacc_config_manager().get_model_infos()
if config_infos != "default":
model_name = config_infos
if not init_huge_memory_allocator(LLM_MAX_PREFILL_SEQ_LEN, self.config.hidden_size, vllm_model=model_name):
logger.warning("init huge memory allocator fail. prefill memory recycling will disable")
from vllm.model_executor.models.utils import AutoWeightsLoader
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)