Merge pull request #2 from ldh2020/ldh2020-qwen3-next

[Model] Optimize the performance of Qwen3-Next
This commit is contained in:
ldh2020
2025-12-22 11:11:01 +08:00
committed by GitHub
3 changed files with 32 additions and 15 deletions

View File

@@ -3,7 +3,7 @@
"""Inference-only Qwen3Next model.""" """Inference-only Qwen3Next model."""
from collections.abc import Iterable from collections.abc import Iterable
from itertools import islice from itertools import islice
from typing import Optional from typing import Optional, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@@ -69,10 +69,8 @@ from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer,
maybe_prefix) maybe_prefix)
from vllm_kunlun.ops.activation import SiluAndMul from vllm_kunlun.ops.activation import SiluAndMul
from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
from typing import Optional, Union
from vllm.model_executor.layers.vocab_parallel_embedding import get_masked_input_and_mask from vllm.model_executor.layers.vocab_parallel_embedding import get_masked_input_and_mask
import xtorch_ops
@torch.compile(dynamic=True, backend="aot_eager") @torch.compile(dynamic=True, backend="aot_eager")
@@ -613,12 +611,17 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
# 3.2: process the remaining part # 3.2: process the remaining part
if attn_metadata.num_prefills > 0: if attn_metadata.num_prefills > 0:
initial_state = ssm_state[ if non_spec_state_indices_tensor.shape[0] > 100:
non_spec_state_indices_tensor].contiguous() initial_state = ssm_state[
initial_state[~has_initial_state, ...] = 0 non_spec_state_indices_tensor].contiguous()
if self.num_v_heads // self.num_k_heads > 1: else:
query_non_spec = query_non_spec.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) initial_state_shape = non_spec_state_indices_tensor.shape + ssm_state.shape[1: ]
key_non_spec = key_non_spec.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) initial_state = torch.empty(initial_state_shape, dtype=ssm_state.dtype, device=ssm_state.device)
for i in range(non_spec_state_indices_tensor.shape[0]):
initial_state[i] = ssm_state[non_spec_state_indices_tensor[i]]
initial_state = initial_state * has_initial_state.view(has_initial_state.shape[0], 1, 1, 1)
initial_state = initial_state.transpose(-1, -2).contiguous()
( (
core_attn_out_non_spec, core_attn_out_non_spec,
last_recurrent_state, last_recurrent_state,
@@ -634,8 +637,15 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
cu_seqlens=non_spec_query_start_loc, cu_seqlens=non_spec_query_start_loc,
) )
# Init cache # Init cache
ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to( last_recurrent_state = last_recurrent_state.transpose(-1, -2).contiguous().to(ssm_state.dtype).view(
ssm_state.dtype) last_recurrent_state.shape[0], -1, last_recurrent_state.shape[-1])
cast_ssm_state = ssm_state.view(ssm_state.shape[0], 1, -1, ssm_state.shape[-1])
xtorch_ops.reshape_and_cache_flash(
last_recurrent_state,
last_recurrent_state,
cast_ssm_state,
cast_ssm_state,
non_spec_state_indices_tensor)
elif attn_metadata.num_decodes > 0: elif attn_metadata.num_decodes > 0:
core_attn_out_non_spec, last_recurrent_state = ( core_attn_out_non_spec, last_recurrent_state = (
fused_recurrent_gated_delta_rule( fused_recurrent_gated_delta_rule(

View File

@@ -44,6 +44,7 @@ class FusedRecurrentFunction(torch.autograd.Function):
h0_indices=ssm_state_indices, h0_indices=ssm_state_indices,
num_accepted_tokens=num_accepted_tokens, num_accepted_tokens=num_accepted_tokens,
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
is_h0_transposed=True
) )
return o, final_state return o, final_state
@@ -150,4 +151,4 @@ def fused_recurrent_gated_delta_rule(
num_accepted_tokens, num_accepted_tokens,
use_qk_l2norm_in_kernel, use_qk_l2norm_in_kernel,
) )
return o, final_state return o, final_state

View File

@@ -673,6 +673,12 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
if prefill_meta := attn_metadata.prefill_metadata: if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run. # Prompt run.
prefill_query = query[num_decode_tokens:attn_metadata.num_actual_tokens] prefill_query = query[num_decode_tokens:attn_metadata.num_actual_tokens]
if key_cache.is_contiguous():
tmp_block_tables = prefill_meta.block_tables
else:
tmp_block_tables = prefill_meta.block_tables * 2 # NOTE: only tested with Qwen3-Next
xtorch_ops.prefill_attention( xtorch_ops.prefill_attention(
q=prefill_query, q=prefill_query,
k=key_cache, # Key Cache (block_num, head, block_size, dim) k=key_cache, # Key Cache (block_num, head, block_size, dim)
@@ -680,7 +686,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
out=output[num_decode_tokens:attn_metadata.num_actual_tokens], out=output[num_decode_tokens:attn_metadata.num_actual_tokens],
is_causal=True, is_causal=True,
is_prefix_cache=True, is_prefix_cache=True,
block_table=prefill_meta.block_tables, block_table=tmp_block_tables,
context_qlen_lod_cpu=prefill_meta.query_start_loc_host, context_qlen_lod_cpu=prefill_meta.query_start_loc_host,
context_qlen_lod_xpu=prefill_meta.query_start_loc, context_qlen_lod_xpu=prefill_meta.query_start_loc,
context_kvlen_lod_cpu=prefill_meta.kv_lod_cpu, context_kvlen_lod_cpu=prefill_meta.kv_lod_cpu,
@@ -782,4 +788,4 @@ def use_cascade_attention(
flash_decoding_time = cdiv(flash_decoding_ctas, num_sms) flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)
# Use cascade attention if it is faster than FlashDecoding. # Use cascade attention if it is faster than FlashDecoding.
return cascade_time < flash_decoding_time return cascade_time < flash_decoding_time