[1/N][Refactor][Qwen3-Next] remove redundant Qwen3NextSparseMoeBlock and Qwen3NextAttention (#3019)

### What this PR does / why we need it?
remove redundant Qwen3NextSparseMoeBlock and Qwen3NextAttention

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
```
def main():
    prompts = [
        "The future of AI is",
    ]

    sampling_params = SamplingParams(max_tokens=100, temperature=0.6, top_k=40, top_p=0.95)
    # Create an LLM.
    llm = LLM(
        # model="/root/.cache/modelscope/hub/models/Qwen/Qwen3-30B-A3B",
        model="Qwen/Qwen3-Next-80B-A3B-Instruct",
              tensor_parallel_size=4,
              enforce_eager=True,
              trust_remote_code=True,
              max_model_len=256,
              gpu_memory_utilization=0.7,
              block_size=64,
              )

    # Generate texts from the prompts.
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```

- vLLM version: v0.10.2
- vLLM main:
9d1c50a5ac

---------

Signed-off-by: Icey <1790571317@qq.com>
This commit is contained in:
Icey
2025-09-22 11:24:08 +08:00
committed by GitHub
parent 88d24cce8b
commit 14b39d3c70

View File

@@ -11,11 +11,11 @@ from einops import rearrange
from torch import nn
from transformers.activations import ACT2FN
from vllm import envs
from vllm.attention import Attention, AttentionBackend, AttentionMetadata
from vllm.attention import AttentionBackend, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig,
VllmConfig, get_current_vllm_config)
from vllm.distributed import (divide, get_ep_group, get_pp_group,
from vllm.distributed import (divide, get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.forward_context import ForwardContext, get_forward_context
@@ -27,8 +27,6 @@ from vllm.model_executor.layers.layernorm import \
# yapf: enable
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.abstract import MambaBase
@@ -37,10 +35,6 @@ from vllm.model_executor.layers.mamba.mamba_mixer2 import \
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import \
GPTQMarlinConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
@@ -50,6 +44,8 @@ from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
SupportsLoRA, SupportsPP)
from vllm.model_executor.models.mamba_cache import MambaCacheParams
from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
from vllm.model_executor.models.qwen3_next import (Qwen3NextAttention,
Qwen3NextSparseMoeBlock)
from vllm.model_executor.models.utils import (
AutoWeightsLoader, PPMissingLayer, extract_layer_index,
is_pp_missing_parameter, make_empty_intermediate_tensors_factory,
@@ -68,112 +64,6 @@ from vllm_ascend.ops.fla import RMSNormGated, fused_gdn_gating
from vllm_ascend.ops.sigmoid_gating import fused_recurrent_gated_delta_rule
class Qwen3NextSparseMoeBlock(nn.Module):
def __init__(
self,
config: Qwen3NextConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
enable_eplb: bool = False,
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.ep_group = get_ep_group().device_group
self.ep_rank = self.ep_group.rank()
self.ep_size = self.ep_group.size()
self.n_routed_experts = config.num_experts
if self.tp_size > config.num_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.num_experts}.")
# Load balancing settings.
vllm_config = get_current_vllm_config()
eplb_config = vllm_config.parallel_config.eplb_config
self.enable_eplb = enable_eplb
self.n_logical_experts = self.n_routed_experts
self.n_redundant_experts = eplb_config.num_redundant_experts
self.n_physical_experts = (self.n_logical_experts +
self.n_redundant_experts)
self.n_local_physical_experts = self.n_physical_experts // self.ep_size
self.physical_expert_start = (self.ep_rank *
self.n_local_physical_experts)
self.physical_expert_end = (self.physical_expert_start +
self.n_local_physical_experts)
self.experts = FusedMoE(num_experts=self.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts",
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts)
self.gate = ReplicatedLinear(
config.hidden_size,
config.num_experts,
bias=False,
quant_config=self._maybe_ignore_quant_config(quant_config),
prefix=f"{prefix}.gate")
if config.shared_expert_intermediate_size > 0:
self.shared_expert = Qwen3NextMLP(
hidden_size=config.hidden_size,
intermediate_size=config.shared_expert_intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=self.experts.must_reduce_shared_expert_outputs(
),
)
else:
self.shared_expert = None
self.shared_expert_gate = torch.nn.Linear(config.hidden_size,
1,
bias=False)
def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
# GPTQ configs do not have a list of ignored modules, however AutoGPTQ
# seems to avoid gate quantization.
# See: https://huggingface.co/Qwen/Qwen3-30B-A3B-GPTQ-Int4
if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
return None
return quant_config
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape.
orig_shape = hidden_states.shape
hidden_dim = hidden_states.shape[-1]
hidden_states = hidden_states.view(-1, hidden_dim)
shared_output = None
if self.shared_expert is not None:
shared_output = self.shared_expert(hidden_states)
if self.shared_expert_gate is not None:
shared_output = F.sigmoid(
self.shared_expert_gate(hidden_states)) * shared_output
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501
final_hidden_states)
return final_hidden_states.view(orig_shape)
def torch_chunk_gated_delta_rule(
query,
key,
@@ -473,7 +363,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
output: torch.Tensor,
cache_params: Optional[MambaCacheParams] = None,
):
return torch.ops.vllm.gdn_attention(
return torch.ops.vllm.npu_gdn_attention(
hidden_states,
output,
self.prefix,
@@ -737,123 +627,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
output[:num_actual_tokens], _ = self.out_proj(core_attn_out)
class Qwen3NextAttention(nn.Module):
def __init__(
self,
config: Qwen3NextConfig,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = config.num_attention_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = config.num_key_value_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = config.head_dim or (self.hidden_size // self.num_heads)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.dual_chunk_attention_config = getattr(
config, "dual_chunk_attention_config", None)
self.attn_output_gate = getattr(config, "attn_output_gate", True)
self.qkv_proj = QKVParallelLinear(
config.hidden_size,
self.head_dim,
self.total_num_heads * (1 + self.attn_output_gate),
self.total_num_kv_heads,
bias=getattr(config, "qkv_bias", False),
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
config.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
head_size=self.head_dim,
rotary_dim=self.head_dim,
max_position=config.max_position_embeddings,
base=config.rope_theta,
rope_scaling=config.rope_scaling,
partial_rotary_factor=config.partial_rotary_factor,
dual_chunk_attention_config=self.dual_chunk_attention_config,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
**{
"layer_idx": extract_layer_index(prefix),
"dual_chunk_attention_config":
self.dual_chunk_attention_config,
} if self.dual_chunk_attention_config else {},
)
self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.k_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
output: torch.Tensor,
hidden_states: torch.Tensor,
):
qkv, _ = self.qkv_proj(hidden_states)
if self.attn_output_gate:
q_gate, k, v = qkv.split(
[self.q_size * 2, self.kv_size, self.kv_size], dim=-1)
orig_shape = q_gate.shape[:-1]
q_gate = q_gate.view(*orig_shape, self.num_heads, -1)
q, gate = torch.chunk(q_gate, 2, dim=-1)
q = q.reshape(*orig_shape, -1)
gate = gate.reshape(*orig_shape, -1)
else:
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)
q = self.q_norm(q.view(-1, self.num_heads, self.head_dim)).view(
-1, self.num_heads * self.head_dim)
k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim)).view(
-1, self.num_kv_heads * self.head_dim)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
if self.attn_output_gate:
gate = torch.sigmoid(gate)
attn_output = attn_output * gate
output[:], _ = self.o_proj(attn_output)
class Qwen3NextDecoderLayer(nn.Module):
def __init__(
@@ -1325,7 +1098,7 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
return self.model.get_expert_mapping()
def gdn_attention(
def npu_gdn_attention(
hidden_states: torch.Tensor,
output: torch.Tensor,
layer_name: str,
@@ -1335,7 +1108,7 @@ def gdn_attention(
self._forward(hidden_states=hidden_states, output=output)
def gdn_attention_fake(
def npu_gdn_attention_fake(
hidden_states: torch.Tensor,
output: torch.Tensor,
layer_name: str,
@@ -1344,9 +1117,9 @@ def gdn_attention_fake(
direct_register_custom_op(
op_name="gdn_attention",
op_func=gdn_attention,
op_name="npu_gdn_attention",
op_func=npu_gdn_attention,
mutates_args=["output"],
fake_impl=gdn_attention_fake,
fake_impl=npu_gdn_attention_fake,
dispatch_key=current_platform.dispatch_key,
)