Merge pull request #40 from ldh2020/v0.11.0dev

[Kernel] Optimize the performance of Qwen3-Next
Author: Xinyu Dong
Date: 2025-12-22 21:50:27 +08:00
Committed by: GitHub
3 changed files with 32 additions and 15 deletions

@@ -3,7 +3,7 @@
"""Inference-only Qwen3Next model."""
from collections.abc import Iterable
from itertools import islice
from typing import Optional
from typing import Optional, Union
import torch
import torch.nn.functional as F
@@ -69,10 +69,8 @@ from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer,
maybe_prefix)
from vllm_kunlun.ops.activation import SiluAndMul
from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
from typing import Optional, Union
from vllm.model_executor.layers.vocab_parallel_embedding import get_masked_input_and_mask
import xtorch_ops
@torch.compile(dynamic=True, backend="aot_eager")
@@ -613,12 +611,17 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
# 3.2: process the remaining part
if attn_metadata.num_prefills > 0:
initial_state = ssm_state[
non_spec_state_indices_tensor].contiguous()
initial_state[~has_initial_state, ...] = 0
if self.num_v_heads // self.num_k_heads > 1:
query_non_spec = query_non_spec.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
key_non_spec = key_non_spec.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
if non_spec_state_indices_tensor.shape[0] > 100:
initial_state = ssm_state[
non_spec_state_indices_tensor].contiguous()
else:
initial_state_shape = non_spec_state_indices_tensor.shape + ssm_state.shape[1: ]
initial_state = torch.empty(initial_state_shape, dtype=ssm_state.dtype, device=ssm_state.device)
for i in range(non_spec_state_indices_tensor.shape[0]):
initial_state[i] = ssm_state[non_spec_state_indices_tensor[i]]
initial_state = initial_state * has_initial_state.view(has_initial_state.shape[0], 1, 1, 1)
initial_state = initial_state.transpose(-1, -2).contiguous()
(
core_attn_out_non_spec,
last_recurrent_state,
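
Note on the change above: the old path always gathered the initial states with one fancy-index read and then zeroed rows in place; the new path switches strategy on batch size and masks by broadcast multiplication. A minimal standalone sketch of that logic (the function name, the (num_slots, num_heads, head_k_dim, head_v_dim) cache layout, and the threshold default are assumptions taken from this diff):

import torch

def gather_initial_state(ssm_state, indices, has_initial_state, threshold=100):
    # Large batch: one fancy-index gather of all requested cache slots.
    if indices.shape[0] > threshold:
        initial_state = ssm_state[indices].contiguous()
    else:
        # Small batch: copy slot by slot into a preallocated buffer to avoid
        # the cost of the generic gather kernel for just a few rows.
        shape = indices.shape + ssm_state.shape[1:]
        initial_state = torch.empty(shape, dtype=ssm_state.dtype,
                                    device=ssm_state.device)
        for i in range(indices.shape[0]):
            initial_state[i] = ssm_state[indices[i]]
    # Zero the rows that have no cached state via a broadcast multiply
    # instead of boolean-index assignment.
    initial_state = initial_state * has_initial_state.view(-1, 1, 1, 1)
    # The chunked kernel now expects h0 with its last two dims swapped.
    return initial_state.transpose(-1, -2).contiguous()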
@@ -634,8 +637,15 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
cu_seqlens=non_spec_query_start_loc,
)
# Init cache
ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(
ssm_state.dtype)
last_recurrent_state = last_recurrent_state.transpose(-1, -2).contiguous().to(ssm_state.dtype).view(
last_recurrent_state.shape[0], -1, last_recurrent_state.shape[-1])
cast_ssm_state = ssm_state.view(ssm_state.shape[0], 1, -1, ssm_state.shape[-1])
xtorch_ops.reshape_and_cache_flash(
last_recurrent_state,
last_recurrent_state,
cast_ssm_state,
cast_ssm_state,
non_spec_state_indices_tensor)
elif attn_metadata.num_decodes > 0:
core_attn_out_non_spec, last_recurrent_state = (
fused_recurrent_gated_delta_rule(
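
The write-back above replaces a fancy-index scatter into ssm_state with a fused xtorch_ops.reshape_and_cache_flash call on flattened views. For reference, a pure-PyTorch sketch of the semantics that call is assumed to preserve (useful as a correctness check; the function name is hypothetical):

import torch

def scatter_ssm_state_reference(ssm_state, last_recurrent_state, slot_indices):
    # The kernel's final state now comes back transposed (matching the
    # transposed h0 above), so swap the last two dims back to the cache
    # layout, cast, and write each row into its slot in place.
    ssm_state[slot_indices] = (
        last_recurrent_state.transpose(-1, -2).to(ssm_state.dtype))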

@@ -44,6 +44,7 @@ class FusedRecurrentFunction(torch.autograd.Function):
h0_indices=ssm_state_indices,
num_accepted_tokens=num_accepted_tokens,
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
is_h0_transposed=True
)
return o, final_state
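
is_h0_transposed=True tells the fused kernel that the initial state arrives with its last two dims swapped (matching the transposed gather in the model file), so the caller no longer has to re-transpose it. A hedged pure-PyTorch sketch of the assumed semantics:

import torch

def load_h0_reference(initial_state, is_h0_transposed):
    # With the flag set, h0 is passed as (..., head_v_dim, head_k_dim) and the
    # kernel transposes it internally; otherwise it is already in the
    # (..., head_k_dim, head_v_dim) layout the recurrence consumes.
    return initial_state.transpose(-1, -2) if is_h0_transposed else initial_state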
@@ -150,4 +151,4 @@ def fused_recurrent_gated_delta_rule(
num_accepted_tokens,
use_qk_l2norm_in_kernel,
)
return o, final_state
return o, final_state

@@ -673,6 +673,12 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
prefill_query = query[num_decode_tokens:attn_metadata.num_actual_tokens]
if key_cache.is_contiguous():
tmp_block_tables = prefill_meta.block_tables
else:
tmp_block_tables = prefill_meta.block_tables * 2  # only tested on Qwen3-Next
xtorch_ops.prefill_attention(
q=prefill_query,
k=key_cache, # Key Cache (block_num, head, block_size, dim)
@@ -680,7 +686,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
out=output[num_decode_tokens:attn_metadata.num_actual_tokens],
is_causal=True,
is_prefix_cache=True,
block_table=prefill_meta.block_tables,
block_table=tmp_block_tables,
context_qlen_lod_cpu=prefill_meta.query_start_loc_host,
context_qlen_lod_xpu=prefill_meta.query_start_loc,
context_kvlen_lod_cpu=prefill_meta.kv_lod_cpu,
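
The prefill change above picks the block table based on whether the key cache tensor is contiguous; when it is not, the block ids are scaled by two so they address the strided storage (the inline comment notes this has only been exercised with Qwen3-Next). A small sketch of that selection, with the helper name and the stride factor as assumptions:

import torch

def select_block_tables(block_tables, key_cache, stride_factor=2):
    # Contiguous cache: block ids map 1:1 onto physical blocks.
    if key_cache.is_contiguous():
        return block_tables
    # Non-contiguous cache: scale the ids so they index the strided layout.
    return block_tables * stride_factor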
@@ -782,4 +788,4 @@ def use_cascade_attention(
flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)
# Use cascade attention if it is faster than FlashDecoding.
return cascade_time < flash_decoding_time
return cascade_time < flash_decoding_time
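
The trailing hunk only touches the final line of use_cascade_attention, but the heuristic it closes compares estimated wave counts: cascade attention is chosen only when it needs fewer SM waves than flash decoding. A minimal sketch of that comparison, assuming cascade_time is computed symmetrically to the visible flash_decoding_time line (names other than cdiv are assumptions):

def cdiv(a: int, b: int) -> int:
    # Ceiling division: waves needed to run `a` CTAs across `b` SMs.
    return -(-a // b)

def prefer_cascade(cascade_ctas: int, flash_decoding_ctas: int, num_sms: int) -> bool:
    cascade_time = cdiv(cascade_ctas, num_sms)
    flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)
    # Use cascade attention if it is faster than FlashDecoding.
    return cascade_time < flash_decoding_time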