From 14b39d3c700c543e0f47bd5ccf1fb72ccae98c71 Mon Sep 17 00:00:00 2001
From: Icey <1790571317@qq.com>
Date: Mon, 22 Sep 2025 11:24:08 +0800
Subject: [PATCH] [1/N][Refactor][Qwen3-Next] remove redundant Qwen3NextSparseMoeBlock and Qwen3NextAttention (#3019)

### What this PR does / why we need it?
Remove the redundant `Qwen3NextSparseMoeBlock` and `Qwen3NextAttention` classes from `vllm_ascend/models/qwen3_next.py` and import them from upstream vLLM instead.

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
```
from vllm import LLM, SamplingParams


def main():
    prompts = [
        "The future of AI is",
    ]
    sampling_params = SamplingParams(max_tokens=100, temperature=0.6, top_k=40, top_p=0.95)
    # Create an LLM.
    llm = LLM(
        # model="/root/.cache/modelscope/hub/models/Qwen/Qwen3-30B-A3B",
        model="Qwen/Qwen3-Next-80B-A3B-Instruct",
        tensor_parallel_size=4,
        enforce_eager=True,
        trust_remote_code=True,
        max_model_len=256,
        gpu_memory_utilization=0.7,
        block_size=64,
    )
    # Generate texts from the prompts.
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
    main()
```

- vLLM version: v0.10.2
- vLLM main: https://github.com/vllm-project/vllm/commit/9d1c50a5ac8726f4af0d4a4e85ad4d26a674ad26

---------

Signed-off-by: Icey <1790571317@qq.com>
---
 vllm_ascend/models/qwen3_next.py | 247 ++-----------------------------
 1 file changed, 10 insertions(+), 237 deletions(-)

diff --git a/vllm_ascend/models/qwen3_next.py b/vllm_ascend/models/qwen3_next.py
index 6234c83..1606f61 100644
--- a/vllm_ascend/models/qwen3_next.py
+++ b/vllm_ascend/models/qwen3_next.py
@@ -11,11 +11,11 @@ from einops import rearrange
 from torch import nn
 from transformers.activations import ACT2FN
 from vllm import envs
-from vllm.attention import Attention, AttentionBackend, AttentionMetadata
+from vllm.attention import AttentionBackend, AttentionMetadata
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig,
                          VllmConfig, get_current_vllm_config)
-from vllm.distributed import (divide, get_ep_group, get_pp_group,
+from vllm.distributed import (divide, get_pp_group,
                               get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.forward_context import ForwardContext, get_forward_context
@@ -27,8 +27,6 @@ from vllm.model_executor.layers.layernorm import \
 # yapf: enable
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
-                                               QKVParallelLinear,
-                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.abstract import MambaBase
@@ -37,10 +35,6 @@ from vllm.model_executor.layers.mamba.mamba_mixer2 import \
 from vllm.model_executor.layers.mamba.mamba_utils import (
     MambaStateDtypeCalculator, MambaStateShapeCalculator)
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.quantization.gptq import GPTQConfig
-from vllm.model_executor.layers.quantization.gptq_marlin import \
-    GPTQMarlinConfig
-from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
@@ -50,6 +44,8 @@ from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
                                                     SupportsLoRA, SupportsPP)
 from vllm.model_executor.models.mamba_cache import MambaCacheParams
 from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
+from vllm.model_executor.models.qwen3_next import (Qwen3NextAttention,
+                                                   Qwen3NextSparseMoeBlock)
 from vllm.model_executor.models.utils import (
     AutoWeightsLoader, PPMissingLayer, extract_layer_index,
     is_pp_missing_parameter, make_empty_intermediate_tensors_factory,
@@ -68,112 +64,6 @@ from vllm_ascend.ops.fla import RMSNormGated, fused_gdn_gating
 from vllm_ascend.ops.sigmoid_gating import fused_recurrent_gated_delta_rule
 
 
-class Qwen3NextSparseMoeBlock(nn.Module):
-
-    def __init__(
-        self,
-        config: Qwen3NextConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-        enable_eplb: bool = False,
-    ):
-        super().__init__()
-        self.tp_size = get_tensor_model_parallel_world_size()
-
-        self.ep_group = get_ep_group().device_group
-        self.ep_rank = self.ep_group.rank()
-        self.ep_size = self.ep_group.size()
-        self.n_routed_experts = config.num_experts
-
-        if self.tp_size > config.num_experts:
-            raise ValueError(
-                f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {config.num_experts}.")
-
-        # Load balancing settings.
-        vllm_config = get_current_vllm_config()
-        eplb_config = vllm_config.parallel_config.eplb_config
-        self.enable_eplb = enable_eplb
-
-        self.n_logical_experts = self.n_routed_experts
-        self.n_redundant_experts = eplb_config.num_redundant_experts
-        self.n_physical_experts = (self.n_logical_experts +
-                                   self.n_redundant_experts)
-        self.n_local_physical_experts = self.n_physical_experts // self.ep_size
-
-        self.physical_expert_start = (self.ep_rank *
-                                      self.n_local_physical_experts)
-        self.physical_expert_end = (self.physical_expert_start +
-                                    self.n_local_physical_experts)
-
-        self.experts = FusedMoE(num_experts=self.n_routed_experts,
-                                top_k=config.num_experts_per_tok,
-                                hidden_size=config.hidden_size,
-                                intermediate_size=config.moe_intermediate_size,
-                                reduce_results=False,
-                                renormalize=config.norm_topk_prob,
-                                quant_config=quant_config,
-                                prefix=f"{prefix}.experts",
-                                enable_eplb=self.enable_eplb,
-                                num_redundant_experts=self.n_redundant_experts)
-
-        self.gate = ReplicatedLinear(
-            config.hidden_size,
-            config.num_experts,
-            bias=False,
-            quant_config=self._maybe_ignore_quant_config(quant_config),
-            prefix=f"{prefix}.gate")
-
-        if config.shared_expert_intermediate_size > 0:
-            self.shared_expert = Qwen3NextMLP(
-                hidden_size=config.hidden_size,
-                intermediate_size=config.shared_expert_intermediate_size,
-                hidden_act=config.hidden_act,
-                quant_config=quant_config,
-                reduce_results=self.experts.must_reduce_shared_expert_outputs(
-                ),
-            )
-        else:
-            self.shared_expert = None
-        self.shared_expert_gate = torch.nn.Linear(config.hidden_size,
-                                                  1,
-                                                  bias=False)
-
-    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
-        # GPTQ configs do not have a list of ignored modules, however AutoGPTQ
-        # seems to avoid gate quantization.
-        # See: https://huggingface.co/Qwen/Qwen3-30B-A3B-GPTQ-Int4
-        if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
-            return None
-        return quant_config
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        # NOTE: hidden_states can have either 1D or 2D shape.
-        orig_shape = hidden_states.shape
-        hidden_dim = hidden_states.shape[-1]
-        hidden_states = hidden_states.view(-1, hidden_dim)
-
-        shared_output = None
-        if self.shared_expert is not None:
-            shared_output = self.shared_expert(hidden_states)
-            if self.shared_expert_gate is not None:
-                shared_output = F.sigmoid(
-                    self.shared_expert_gate(hidden_states)) * shared_output
-
-        # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(hidden_states=hidden_states,
-                                           router_logits=router_logits)
-
-        if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
-        if self.tp_size > 1:
-            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(  # noqa E501
-                final_hidden_states)
-
-        return final_hidden_states.view(orig_shape)
-
-
 def torch_chunk_gated_delta_rule(
     query,
     key,
@@ -473,7 +363,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
         output: torch.Tensor,
         cache_params: Optional[MambaCacheParams] = None,
     ):
-        return torch.ops.vllm.gdn_attention(
+        return torch.ops.vllm.npu_gdn_attention(
            hidden_states,
            output,
            self.prefix,
@@ -737,123 +627,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
         output[:num_actual_tokens], _ = self.out_proj(core_attn_out)
 
-
-class Qwen3NextAttention(nn.Module):
-
-    def __init__(
-        self,
-        config: Qwen3NextConfig,
-        model_config: Optional[ModelConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
-        super().__init__()
-        self.config = config
-        self.hidden_size = config.hidden_size
-        tp_size = get_tensor_model_parallel_world_size()
-        self.total_num_heads = config.num_attention_heads
-        assert self.total_num_heads % tp_size == 0
-        self.num_heads = self.total_num_heads // tp_size
-        self.total_num_kv_heads = config.num_key_value_heads
-        if self.total_num_kv_heads >= tp_size:
-            # Number of KV heads is greater than TP size, so we partition
-            # the KV heads across multiple tensor parallel GPUs.
-            assert self.total_num_kv_heads % tp_size == 0
-        else:
-            # Number of KV heads is less than TP size, so we replicate
-            # the KV heads across multiple tensor parallel GPUs.
-            assert tp_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = config.head_dim or (self.hidden_size // self.num_heads)
-        self.q_size = self.num_heads * self.head_dim
-        self.kv_size = self.num_kv_heads * self.head_dim
-        self.scaling = self.head_dim**-0.5
-        self.dual_chunk_attention_config = getattr(
-            config, "dual_chunk_attention_config", None)
-        self.attn_output_gate = getattr(config, "attn_output_gate", True)
-
-        self.qkv_proj = QKVParallelLinear(
-            config.hidden_size,
-            self.head_dim,
-            self.total_num_heads * (1 + self.attn_output_gate),
-            self.total_num_kv_heads,
-            bias=getattr(config, "qkv_bias", False),
-            quant_config=quant_config,
-            prefix=f"{prefix}.qkv_proj",
-        )
-
-        self.o_proj = RowParallelLinear(
-            self.total_num_heads * self.head_dim,
-            config.hidden_size,
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.o_proj",
-        )
-
-        self.rotary_emb = get_rope(
-            head_size=self.head_dim,
-            rotary_dim=self.head_dim,
-            max_position=config.max_position_embeddings,
-            base=config.rope_theta,
-            rope_scaling=config.rope_scaling,
-            partial_rotary_factor=config.partial_rotary_factor,
-            dual_chunk_attention_config=self.dual_chunk_attention_config,
-        )
-
-        self.attn = Attention(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            num_kv_heads=self.num_kv_heads,
-            cache_config=cache_config,
-            quant_config=quant_config,
-            prefix=f"{prefix}.attn",
-            **{
-                "layer_idx": extract_layer_index(prefix),
-                "dual_chunk_attention_config":
-                self.dual_chunk_attention_config,
-            } if self.dual_chunk_attention_config else {},
-        )
-
-        self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
-        self.k_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
-
-    def forward(
-        self,
-        positions: torch.Tensor,
-        output: torch.Tensor,
-        hidden_states: torch.Tensor,
-    ):
-        qkv, _ = self.qkv_proj(hidden_states)
-
-        if self.attn_output_gate:
-            q_gate, k, v = qkv.split(
-                [self.q_size * 2, self.kv_size, self.kv_size], dim=-1)
-            orig_shape = q_gate.shape[:-1]
-            q_gate = q_gate.view(*orig_shape, self.num_heads, -1)
-            q, gate = torch.chunk(q_gate, 2, dim=-1)
-            q = q.reshape(*orig_shape, -1)
-            gate = gate.reshape(*orig_shape, -1)
-        else:
-            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
-                                dim=-1)
-
-        q = self.q_norm(q.view(-1, self.num_heads, self.head_dim)).view(
-            -1, self.num_heads * self.head_dim)
-        k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim)).view(
-            -1, self.num_kv_heads * self.head_dim)
-
-        q, k = self.rotary_emb(positions, q, k)
-
-        attn_output = self.attn(q, k, v)
-
-        if self.attn_output_gate:
-            gate = torch.sigmoid(gate)
-            attn_output = attn_output * gate
-
-        output[:], _ = self.o_proj(attn_output)
-
-
 class Qwen3NextDecoderLayer(nn.Module):
 
     def __init__(
@@ -1325,7 +1098,7 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
         return self.model.get_expert_mapping()
 
 
-def gdn_attention(
+def npu_gdn_attention(
     hidden_states: torch.Tensor,
     output: torch.Tensor,
     layer_name: str,
@@ -1335,7 +1108,7 @@
     self._forward(hidden_states=hidden_states, output=output)
 
 
-def gdn_attention_fake(
+def npu_gdn_attention_fake(
    hidden_states: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
@@ -1344,9 +1117,9 @@
 
 
 direct_register_custom_op(
-    op_name="gdn_attention",
-    op_func=gdn_attention,
+    op_name="npu_gdn_attention",
+    op_func=npu_gdn_attention,
    mutates_args=["output"],
-    fake_impl=gdn_attention_fake,
+    fake_impl=npu_gdn_attention_fake,
     dispatch_key=current_platform.dispatch_key,
 )
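For reference, the renamed `npu_gdn_attention` op at the end of the diff follows the standard PyTorch custom-op pattern: a real implementation that writes into `output` in place (hence `mutates_args=["output"]`), plus a fake implementation so `torch.compile` can trace the op without running the kernel. The sketch below illustrates that pattern with plain `torch.library` rather than vLLM's `direct_register_custom_op` helper used in the patch; the `demo` namespace and the `*_demo` function names are hypothetical stand-ins, not part of this change.

```
# Minimal sketch of the out-variant custom-op pattern (assumption: the "demo"
# namespace and function names are illustrative only).
import torch


def npu_gdn_attention_demo(hidden_states: torch.Tensor, output: torch.Tensor,
                           layer_name: str) -> None:
    # The real op looks up the layer by name in the forward context and runs
    # its _forward(); here we only copy to show the in-place output contract.
    output.copy_(hidden_states)


def npu_gdn_attention_demo_fake(hidden_states: torch.Tensor,
                                output: torch.Tensor,
                                layer_name: str) -> None:
    # Fake (meta) impl: no computation, only shape/dtype semantics for tracing.
    return


lib = torch.library.Library("demo", "FRAGMENT")
lib.define("npu_gdn_attention(Tensor hidden_states, Tensor(a!) output, "
           "str layer_name) -> ()")
lib.impl("npu_gdn_attention", npu_gdn_attention_demo,
         "CompositeExplicitAutograd")
torch.library.register_fake("demo::npu_gdn_attention",
                            npu_gdn_attention_demo_fake)

# Callers dispatch through torch.ops, mirroring the
# torch.ops.vllm.npu_gdn_attention(...) call in Qwen3NextGatedDeltaNet.
x = torch.randn(4, 8)
out = torch.empty_like(x)
torch.ops.demo.npu_gdn_attention(x, out, "layers.0.linear_attn")
```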