support qwen3-next full_decode_only mode. (#3949)

### What this PR does / why we need it?
Support qwen3-next full_decode_only mode.
Benchmark setup: bs (batch size) = 1, max_token = 1024.
| branch | tps | e2e time |
| --- | --- | --- |
| piecewise | 3.06 | 8.15 |
| full_decode_only | 7.2 | 3.47 |

- vLLM version: v0.11.0
- vLLM main:
83f478bb19

Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
This commit is contained in:
XiaoxinWang
2025-11-05 08:46:05 +08:00
committed by GitHub
parent 5f08e07208
commit 738bf2b720
4 changed files with 66 additions and 9 deletions

View File

@@ -76,7 +76,8 @@ from vllm.utils import cdiv, length_from_prompt_token_ids_or_embeds
from vllm.utils.jsontree import json_map_leaves
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import (
AttentionCGSupport, reorder_batch_to_split_decodes_and_prefills)
AttentionCGSupport, CommonAttentionMetadata,
reorder_batch_to_split_decodes_and_prefills)
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
# yapf conflicts with isort for this block
# yapf: disable
@@ -107,7 +108,8 @@ from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import (MoECommType,
set_ascend_forward_context)
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.attention_v1 import (AscendAttentionMetadataBuilder,
AscendAttentionState)
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
AscendPrefillContextParallelMetadata)
# yapf: disable
@@ -2644,6 +2646,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_reqs: int,
num_tokens: int,
max_query_len: int,
num_scheduled_tokens: np.ndarray,
aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
force_attention: bool = False,
) -> Optional[dict[str, Any]]:
@@ -2659,6 +2662,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.seq_lens_np[:num_reqs] = seq_lens
self.seq_lens_np[num_reqs:] = 0
cu_num_tokens, arange = self._get_cumsum_and_arange(
num_scheduled_tokens)
query_start_loc_tensor = torch.Tensor(cu_num_tokens).to(
self.device).to(torch.int32)
self.query_start_loc[1:num_reqs + 1] = query_start_loc_tensor
self.query_start_loc_cpu[1:num_reqs +
1] = torch.Tensor(cu_num_tokens)
num_computed_tokens_cpu = (
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
@@ -2715,12 +2726,35 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.speculative_config.method == "deepseek_mtp":
attn_state = AscendAttentionState.SpecDecoding
common_metadata = CommonAttentionMetadata(
query_start_loc=self.query_start_loc[:num_reqs + 1],
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
1],
seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
seq_lens=self.seq_lens_cpu[:num_reqs],
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
block_table_tensor=block_table_tensor[:num_reqs],
slot_mapping=slot_mapping,
num_computed_tokens_cpu=num_computed_tokens_cpu,
max_query_len=max_query_len,
max_seq_len=seq_lens)
for attn_group in self.attn_groups[kv_cache_group_id]:
builder = attn_group.get_metadata_builder()
attn_metadata_i = builder.build_for_graph_capture(
common_attn_metadata, attn_state, self.get_model())
if isinstance(builder, AscendAttentionMetadataBuilder):
attn_metadata_full_attention = builder.build_for_graph_capture(
common_attn_metadata, attn_state, self.get_model())
elif isinstance(builder, GDNAttentionMetadataBuilder):
attn_metadata_gdn_attention = builder.build_for_cudagraph_capture(
common_metadata)
for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = attn_metadata_i
if "linear_attn" in layer_name:
attn_metadata[
layer_name] = attn_metadata_gdn_attention
else:
attn_metadata[
layer_name] = attn_metadata_full_attention
return attn_metadata
@@ -2895,6 +2929,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
max_query_len=max_query_len,
aclgraph_runtime_mode=aclgraph_runtime_mode,
force_attention=force_attention,
num_scheduled_tokens=num_scheduled_tokens,
)
need_dummy_logits = (not self.in_profile_run