[2/N][Refactor][Qwen3-Next] remove redundant methods and patch methods in Qwen3NextGatedDeltaNet (#3082)

### What this PR does / why we need it?
remove redundant methods and patch methods in Qwen3NextGatedDeltaNet
involved causal_conv1d_fn, causal_conv1d_update_npu, fused_gdn_gating,
fused_reccrrent_gated_delta_rule, torch_chunk_gated_delta_rule,
RMSNormGated

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
```
def main():
    prompts = [
        "The future of AI is",
    ]

    # Create a sampling params object.
    sampling_params = SamplingParams(max_tokens=100, temperature=0.6, top_k=40, top_p=0.95)
    # Create an LLM.
    llm = LLM(
        model="Qwen/Qwen3-Next-80B-A3B-Instruct",
              tensor_parallel_size=4,
              enforce_eager=True,
              trust_remote_code=True,
              max_model_len=256,
              gpu_memory_utilization=0.7,
              block_size=64,
              )
    # Generate texts from the prompts.
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```

CI passed with new added/existing test.


- vLLM version: v0.10.2
- vLLM main:
5aeb925452

---------

Signed-off-by: Icey <1790571317@qq.com>
This commit is contained in:
Icey
2025-09-24 11:25:42 +08:00
committed by GitHub
parent eb205d9f35
commit e7618d9414
6 changed files with 667 additions and 980 deletions

View File

@@ -6,7 +6,6 @@ from collections.abc import Iterable
from typing import Optional
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from transformers.activations import ACT2FN
@@ -19,6 +18,10 @@ from vllm.distributed import (divide, get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.fla.ops import RMSNormGated
from vllm.model_executor.layers.fla.ops.chunk import chunk_gated_delta_rule
from vllm.model_executor.layers.fla.ops.fused_recurrent import \
fused_recurrent_gated_delta_rule
from vllm.model_executor.layers.fused_moe import FusedMoE
# yapf conflicts with isort for this block
# yapf: disable
@@ -34,6 +37,8 @@ from vllm.model_executor.layers.mamba.mamba_mixer2 import \
mamba_v2_sharded_weight_loader
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
@@ -45,7 +50,8 @@ from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
from vllm.model_executor.models.mamba_cache import MambaCacheParams
from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
from vllm.model_executor.models.qwen3_next import (Qwen3NextAttention,
Qwen3NextSparseMoeBlock)
Qwen3NextSparseMoeBlock,
fused_gdn_gating)
from vllm.model_executor.models.utils import (
AutoWeightsLoader, PPMissingLayer, extract_layer_index,
is_pp_missing_parameter, make_empty_intermediate_tensors_factory,
@@ -57,108 +63,6 @@ from vllm.transformers_utils.configs import Qwen3NextConfig
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
from vllm_ascend.ops.casual_conv1d import (causal_conv1d_fn,
causal_conv1d_update_npu)
from vllm_ascend.ops.fla import RMSNormGated, fused_gdn_gating
from vllm_ascend.ops.sigmoid_gating import fused_recurrent_gated_delta_rule
def torch_chunk_gated_delta_rule(
query,
key,
value,
g,
beta,
chunk_size=64,
initial_state=None,
output_final_state=False,
use_qk_l2norm_in_kernel=False,
):
initial_dtype = query.dtype
if use_qk_l2norm_in_kernel:
query = F.normalize(query, p=2, dim=-1)
key = F.normalize(key, p=2, dim=-1)
query, key, value, beta, g = [
x.transpose(1, 2).contiguous().to(torch.float32)
for x in (query, key, value, beta, g)
]
batch_size, sequence_length, num_heads, k_head_dim = key.shape
v_head_dim = value.shape[-1]
pad_size = (chunk_size - num_heads % chunk_size) % chunk_size
query = F.pad(query, (0, 0, 0, pad_size)).repeat_interleave(2, dim=1)
key = F.pad(key, (0, 0, 0, pad_size)).repeat_interleave(2, dim=1)
value = F.pad(value, (0, 0, 0, pad_size))
beta = F.pad(beta, (0, pad_size))
g = F.pad(g, (0, pad_size))
tot_heads = num_heads + pad_size
scale = 1 / (query.shape[-1]**0.5)
query = query * scale
v_beta = value * beta.unsqueeze(-1)
k_beta = key * beta.unsqueeze(-1)
# reshape to chunks
query, key, value, k_beta, v_beta = [
x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1])
for x in (query, key, value, k_beta, v_beta)
]
g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
mask = torch.triu(torch.ones(chunk_size,
chunk_size,
dtype=torch.bool,
device=query.device),
diagonal=0)
# chunk decay
g = g.cumsum(dim=-1)
decay_mask = ((g.unsqueeze(-1) -
g.unsqueeze(-2)).tril().exp().float()).tril()
attn = -(
(k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
for i in range(1, chunk_size):
row = attn[..., i, :i].clone()
sub = attn[..., :i, :i].clone()
attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
value = attn @ v_beta
k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
last_recurrent_state = (torch.zeros(batch_size, sequence_length,
k_head_dim, v_head_dim).to(value) if
initial_state is None else initial_state.to(value))
core_attn_out = torch.zeros_like(value)
mask = torch.triu(torch.ones(chunk_size,
chunk_size,
dtype=torch.bool,
device=query.device),
diagonal=1)
# for each chunk
for i in range(0, tot_heads // chunk_size):
q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
attn = (q_i @ k_i.transpose(-1, -2) *
decay_mask[:, :, i]).masked_fill_(mask, 0)
v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
v_new = v_i - v_prime
attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
core_attn_out[:, :, i] = attn_inter + attn @ v_new
last_recurrent_state = (
last_recurrent_state * g[:, :, i, -1, None, None].exp() +
(k_i *
(g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(
-1, -2) @ v_new)
if not output_final_state:
last_recurrent_state = None
core_attn_out = core_attn_out.reshape(core_attn_out.shape[0],
core_attn_out.shape[1], -1,
core_attn_out.shape[-1])
core_attn_out = core_attn_out[:, :, :num_heads]
core_attn_out = core_attn_out.transpose(1,
2).contiguous().to(initial_dtype)
return core_attn_out, last_recurrent_state
class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
@@ -275,6 +179,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
self.norm = RMSNormGated(
self.head_v_dim,
eps=self.layer_norm_epsilon,
norm_before_gate=True,
device="npu",
)
self.out_proj = RowParallelLinear(self.value_dim,
@@ -467,7 +373,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
query_start_loc=non_spec_query_start_loc,
).transpose(0, 1)
elif attn_metadata.num_decodes > 0:
mixed_qkv_non_spec = causal_conv1d_update_npu(
mixed_qkv_non_spec = causal_conv1d_update(
mixed_qkv_non_spec,
conv_state,
conv_weights,
@@ -551,7 +457,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
(
cur_core_attn_out_non_spec,
cur_last_recurrent_state,
) = torch_chunk_gated_delta_rule(
) = chunk_gated_delta_rule(
query=cur_q,
key=cur_k,
value=cur_v,