From 58c1db5073a96987d616a0324ce041e69168a3f5 Mon Sep 17 00:00:00 2001 From: ldh2020 <62470572+ldh2020@users.noreply.github.com> Date: Sun, 21 Dec 2025 10:34:43 +0800 Subject: [PATCH 1/4] [Bugfix] fix the bug of the flash_attention in Qwen3-Next --- vllm_kunlun/v1/attention/backends/kunlun_attn.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/vllm_kunlun/v1/attention/backends/kunlun_attn.py b/vllm_kunlun/v1/attention/backends/kunlun_attn.py index 612ca71..4f2555e 100644 --- a/vllm_kunlun/v1/attention/backends/kunlun_attn.py +++ b/vllm_kunlun/v1/attention/backends/kunlun_attn.py @@ -673,6 +673,12 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]): if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. prefill_query = query[num_decode_tokens:attn_metadata.num_actual_tokens] + + if key_cache.is_contiguous(): + tmp_block_tables = prefill_meta.block_tables + else: + tmp_block_tables = prefill_meta.block_tables * 2 # only test in Qwen3-Next + xtorch_ops.prefill_attention( q=prefill_query, k=key_cache, # Key Cache (block_num, head, block_size, dim) @@ -680,7 +686,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]): out=output[num_decode_tokens:attn_metadata.num_actual_tokens], is_causal=True, is_prefix_cache=True, - block_table=prefill_meta.block_tables, + block_table=tmp_block_tables, context_qlen_lod_cpu=prefill_meta.query_start_loc_host, context_qlen_lod_xpu=prefill_meta.query_start_loc, context_kvlen_lod_cpu=prefill_meta.kv_lod_cpu, @@ -782,4 +788,4 @@ def use_cascade_attention( flash_decoding_time = cdiv(flash_decoding_ctas, num_sms) # Use cascade attention if it is faster than FlashDecoding. - return cascade_time < flash_decoding_time \ No newline at end of file + return cascade_time < flash_decoding_time From 004e164bdb1c8676abf0c59e9661f42daa616d8f Mon Sep 17 00:00:00 2001 From: ldh2020 <62470572+ldh2020@users.noreply.github.com> Date: Sun, 21 Dec 2025 11:18:00 +0800 Subject: [PATCH 2/4] [Kernel] Optimize the recurrent op --- vllm_kunlun/ops/fla/fused_recurrent.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm_kunlun/ops/fla/fused_recurrent.py b/vllm_kunlun/ops/fla/fused_recurrent.py index 143b6a0..3902bee 100644 --- a/vllm_kunlun/ops/fla/fused_recurrent.py +++ b/vllm_kunlun/ops/fla/fused_recurrent.py @@ -44,6 +44,7 @@ class FusedRecurrentFunction(torch.autograd.Function): h0_indices=ssm_state_indices, num_accepted_tokens=num_accepted_tokens, use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + is_h0_transposed=True ) return o, final_state @@ -150,4 +151,4 @@ def fused_recurrent_gated_delta_rule( num_accepted_tokens, use_qk_l2norm_in_kernel, ) - return o, final_state \ No newline at end of file + return o, final_state From b97c78130042f6012e70d242fcaf79a8f12ce79e Mon Sep 17 00:00:00 2001 From: ldh2020 <62470572+ldh2020@users.noreply.github.com> Date: Sun, 21 Dec 2025 11:22:06 +0800 Subject: [PATCH 3/4] [Kernel] Optimize the recurrent op --- vllm_kunlun/models/qwen3_next.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm_kunlun/models/qwen3_next.py b/vllm_kunlun/models/qwen3_next.py index d8c0aac..07ef498 100644 --- a/vllm_kunlun/models/qwen3_next.py +++ b/vllm_kunlun/models/qwen3_next.py @@ -616,6 +616,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): initial_state = ssm_state[ non_spec_state_indices_tensor].contiguous() initial_state[~has_initial_state, ...] = 0 + initial_state = initial_state.transpose(-1, -2).contiguous() if self.num_v_heads // self.num_k_heads > 1: query_non_spec = query_non_spec.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) key_non_spec = key_non_spec.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) @@ -634,6 +635,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): cu_seqlens=non_spec_query_start_loc, ) # Init cache + last_recurrent_state = last_recurrent_state.transpose(-1, -2).contiguous() ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to( ssm_state.dtype) elif attn_metadata.num_decodes > 0: From 8261a09e2a138538d1075e03d00a6da92eead706 Mon Sep 17 00:00:00 2001 From: ldh2020 <62470572+ldh2020@users.noreply.github.com> Date: Sun, 21 Dec 2025 15:45:32 +0800 Subject: [PATCH 4/4] [Kernel] Optimize the selection and update OP of ssm state --- vllm_kunlun/models/qwen3_next.py | 34 ++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/vllm_kunlun/models/qwen3_next.py b/vllm_kunlun/models/qwen3_next.py index 07ef498..fea85da 100644 --- a/vllm_kunlun/models/qwen3_next.py +++ b/vllm_kunlun/models/qwen3_next.py @@ -3,7 +3,7 @@ """Inference-only Qwen3Next model.""" from collections.abc import Iterable from itertools import islice -from typing import Optional +from typing import Optional, Union import torch import torch.nn.functional as F @@ -69,10 +69,8 @@ from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, maybe_prefix) from vllm_kunlun.ops.activation import SiluAndMul from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops - - -from typing import Optional, Union from vllm.model_executor.layers.vocab_parallel_embedding import get_masked_input_and_mask +import xtorch_ops @torch.compile(dynamic=True, backend="aot_eager") @@ -613,13 +611,17 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): # 3.2: process the remaining part if attn_metadata.num_prefills > 0: - initial_state = ssm_state[ - non_spec_state_indices_tensor].contiguous() - initial_state[~has_initial_state, ...] = 0 + if non_spec_state_indices_tensor.shape[0] > 100: + initial_state = ssm_state[ + non_spec_state_indices_tensor].contiguous() + else: + initial_state_shape = non_spec_state_indices_tensor.shape + ssm_state.shape[1: ] + initial_state = torch.empty(initial_state_shape, dtype=ssm_state.dtype, device=ssm_state.device) + for i in range(non_spec_state_indices_tensor.shape[0]): + initial_state[i] = ssm_state[non_spec_state_indices_tensor[i]] + + initial_state = initial_state * has_initial_state.view(has_initial_state.shape[0], 1, 1, 1) initial_state = initial_state.transpose(-1, -2).contiguous() - if self.num_v_heads // self.num_k_heads > 1: - query_non_spec = query_non_spec.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) - key_non_spec = key_non_spec.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) ( core_attn_out_non_spec, last_recurrent_state, @@ -635,9 +637,15 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): cu_seqlens=non_spec_query_start_loc, ) # Init cache - last_recurrent_state = last_recurrent_state.transpose(-1, -2).contiguous() - ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to( - ssm_state.dtype) + last_recurrent_state = last_recurrent_state.transpose(-1, -2).contiguous().to(ssm_state.dtype).view( + last_recurrent_state.shape[0], -1, last_recurrent_state.shape[-1]) + cast_ssm_state = ssm_state.view(ssm_state.shape[0], 1, -1, ssm_state.shape[-1]) + xtorch_ops.reshape_and_cache_flash( + last_recurrent_state, + last_recurrent_state, + cast_ssm_state, + cast_ssm_state, + non_spec_state_indices_tensor) elif attn_metadata.num_decodes > 0: core_attn_out_non_spec, last_recurrent_state = ( fused_recurrent_gated_delta_rule(