support qwen3-next full_decode_only mode. (#3949)
### What this PR does / why we need it?
Support the `FULL_DECODE_ONLY` graph mode for Qwen3-Next: decode-only steps are captured and replayed as full ACL graphs instead of piecewise ones. Measured at bs=1, max_tokens=1024:

| Branch | TPS | E2E time |
| --- | --- | --- |
| piecewise | 3.06 | 8.15 |
| full_decode_only | 7.2 | 3.47 |
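For reference, the mode is selected through vLLM's `compilation_config`; a minimal offline-inference sketch (model name and capture sizes mirror the e2e test added below; everything else is illustrative, not part of this PR):

```python
from vllm import LLM, SamplingParams

# Sketch: enable full-graph capture for decode-only steps on Qwen3-Next.
llm = LLM(
    model="Qwen/Qwen3-Next-80B-A3B-Instruct",
    tensor_parallel_size=4,
    compilation_config={
        "cudagraph_mode": "FULL_DECODE_ONLY",
        "cudagraph_capture_sizes": [1, 8, 24, 48, 60],
    },
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=64))
```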
- vLLM version: v0.11.0
- vLLM main: 83f478bb19
Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
E2E test: add a `FULL_DECODE_ONLY` variant of the Qwen3-Next TP4 test.

```diff
@@ -36,3 +36,21 @@ def test_models_distributed_Qwen3_NEXT_TP4():
                     distributed_executor_backend="mp",
                     enforce_eager=True) as vllm_model:
         vllm_model.generate_greedy(example_prompts, max_tokens)
+
+
+def test_models_distributed_Qwen3_NEXT_TP4_FULL_DECODE_ONLY():
+    example_prompts = [
+        "Hello, my name is",
+    ] * 4
+    max_tokens = 5
+    with VllmRunner("Qwen/Qwen3-Next-80B-A3B-Instruct",
+                    tensor_parallel_size=4,
+                    max_model_len=4096,
+                    gpu_memory_utilization=0.8,
+                    distributed_executor_backend="mp",
+                    enforce_eager=False,
+                    compilation_config={
+                        "cudagraph_mode": "FULL_DECODE_ONLY",
+                        "cudagraph_capture_sizes": [1, 8, 24, 48, 60]
+                    }) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)
```
ACL graph wrapper (`update_attn_params` / `update_attn_dcp_pcp_params`):

```diff
@@ -192,8 +192,10 @@ class ACLGraphWrapper:
 
 
 def update_attn_params(update_stream, forward_context, runtime_shape):
     graph_params = get_graph_params()
     # FIXME: Behold! We are using a temporary hack here to update the args
     # for each layer's attention op in the graph.
+    # For Qwen3-next, since the kv_cache_config has already categorized
+    # linear_attn and self_attn, the attn_metadata is first arranged with
+    # self_attn followed by linear_attn. Therefore, using zip directly
+    # filters out the update operations for linear_attn.
     with torch.npu.stream(update_stream):
         for key, param, handle, event in zip(
                 forward_context.attn_metadata,
@@ -289,9 +291,9 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
 
 
 def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
+    graph_params = get_graph_params()
     # FIXME: Behold! We are using a temporary hack here to update the args
     # for each layer's attention op in the graph.
-    graph_params = get_graph_params()
     with torch.npu.stream(update_stream):
         for key, param, handle, event in zip(
                 forward_context.attn_metadata,
```
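The zip-truncation behavior the new comment relies on can be seen in isolation; a standalone sketch with illustrative layer names (not code from this PR):

```python
# attn_metadata lists self_attn layers first, then linear_attn layers;
# graph params were captured for the self_attn layers only.
attn_metadata_keys = ["layers.0.self_attn", "layers.1.self_attn",
                      "layers.2.linear_attn"]
graph_param_handles = ["handle_a", "handle_b"]

# zip() stops at the shortest input, so the trailing linear_attn entry
# never produces an update operation.
for key, handle in zip(attn_metadata_keys, graph_param_handles):
    print(key, "->", handle)
# layers.0.self_attn -> handle_a
# layers.1.self_attn -> handle_b
```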
Torchair model runner: import numpy and thread the new `num_scheduled_tokens` argument through `_build_dummy_attn_metadata`.

```diff
@@ -31,6 +31,7 @@ from vllm.distributed.parallel_state import get_dp_group
 from vllm.forward_context import get_forward_context
 from vllm.logger import logger
 
+import numpy as np
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.platform import NPUPlatform
@@ -178,6 +179,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
         num_reqs: int,
         num_tokens: int,
         max_query_len: int,
+        num_scheduled_tokens: np.ndarray,
         aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
         force_attention: bool = False,
     ) -> Optional[dict[str, Any]]:
@@ -186,7 +188,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
         if with_prefill or self.enable_shared_expert_dp:
             attn_metadata = super()._build_dummy_attn_metadata(
                 with_prefill, num_reqs, num_tokens, max_query_len,
-                aclgraph_runtime_mode, force_attention)
+                num_scheduled_tokens, aclgraph_runtime_mode, force_attention)
         else:
             common_attn_metadata = TorchairCommonAttentionMetadata(
                 num_reqs=num_reqs,
```
NPU model runner imports: pull in `CommonAttentionMetadata` and `AscendAttentionMetadataBuilder`.

```diff
@@ -76,7 +76,8 @@ from vllm.utils import cdiv, length_from_prompt_token_ids_or_embeds
 from vllm.utils.jsontree import json_map_leaves
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
 from vllm.v1.attention.backends.utils import (
-    AttentionCGSupport, reorder_batch_to_split_decodes_and_prefills)
+    AttentionCGSupport, CommonAttentionMetadata,
+    reorder_batch_to_split_decodes_and_prefills)
 from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -107,7 +108,8 @@ from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import (MoECommType,
                                                 set_ascend_forward_context)
 from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
-from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.attention.attention_v1 import (AscendAttentionMetadataBuilder,
+                                                AscendAttentionState)
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          AscendPrefillContextParallelMetadata)
 # yapf: disable
```
`NPUModelRunner._build_dummy_attn_metadata` gains the same `num_scheduled_tokens` parameter.

```diff
@@ -2644,6 +2646,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         num_reqs: int,
         num_tokens: int,
         max_query_len: int,
+        num_scheduled_tokens: np.ndarray,
         aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
         force_attention: bool = False,
     ) -> Optional[dict[str, Any]]:
```
During the dummy run, `query_start_loc` is now populated from the cumulative sum of the scheduled token counts.

```diff
@@ -2659,6 +2662,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self.seq_lens_np[:num_reqs] = seq_lens
         self.seq_lens_np[num_reqs:] = 0
 
+        cu_num_tokens, arange = self._get_cumsum_and_arange(
+            num_scheduled_tokens)
+        query_start_loc_tensor = torch.Tensor(cu_num_tokens).to(
+            self.device).to(torch.int32)
+        self.query_start_loc[1:num_reqs + 1] = query_start_loc_tensor
+        self.query_start_loc_cpu[1:num_reqs +
+                                 1] = torch.Tensor(cu_num_tokens)
+
         num_computed_tokens_cpu = (
             self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
```
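The `query_start_loc` values built here are just an inclusive prefix sum of the per-request scheduled token counts, offset by a leading zero; a small numpy illustration (assuming, per the usage above, that `_get_cumsum_and_arange` returns that cumsum plus a per-token arange):

```python
import numpy as np

# For a decode-only dummy run, each request schedules one token.
num_scheduled_tokens = np.array([1, 1, 1, 1], dtype=np.int32)

cu_num_tokens = np.cumsum(num_scheduled_tokens)  # [1 2 3 4]

# query_start_loc keeps a leading 0 so that slice [i, i+1) covers request i.
query_start_loc = np.zeros(len(cu_num_tokens) + 1, dtype=np.int32)
query_start_loc[1:] = cu_num_tokens              # [0 1 2 3 4]
```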
Graph-capture metadata is now built per attention flavor and routed by layer name: `AscendAttentionMetadataBuilder` handles the full-attention (self_attn) layers, `GDNAttentionMetadataBuilder` the linear-attention (GDN) layers.

```diff
@@ -2715,12 +2726,35 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 self.speculative_config.method == "deepseek_mtp":
             attn_state = AscendAttentionState.SpecDecoding
 
+        common_metadata = CommonAttentionMetadata(
+            query_start_loc=self.query_start_loc[:num_reqs + 1],
+            query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
+                                                         1],
+            seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
+            seq_lens=self.seq_lens_cpu[:num_reqs],
+            num_reqs=num_reqs,
+            num_actual_tokens=num_tokens,
+            block_table_tensor=block_table_tensor[:num_reqs],
+            slot_mapping=slot_mapping,
+            num_computed_tokens_cpu=num_computed_tokens_cpu,
+            max_query_len=max_query_len,
+            max_seq_len=seq_lens)
+
         for attn_group in self.attn_groups[kv_cache_group_id]:
             builder = attn_group.get_metadata_builder()
-            attn_metadata_i = builder.build_for_graph_capture(
-                common_attn_metadata, attn_state, self.get_model())
+            if isinstance(builder, AscendAttentionMetadataBuilder):
+                attn_metadata_full_attention = builder.build_for_graph_capture(
+                    common_attn_metadata, attn_state, self.get_model())
+            elif isinstance(builder, GDNAttentionMetadataBuilder):
+                attn_metadata_gdn_attention = builder.build_for_cudagraph_capture(
+                    common_metadata)
         for layer_name in kv_cache_group_spec.layer_names:
-            attn_metadata[layer_name] = attn_metadata_i
+            if "linear_attn" in layer_name:
+                attn_metadata[
+                    layer_name] = attn_metadata_gdn_attention
+            else:
+                attn_metadata[
+                    layer_name] = attn_metadata_full_attention
 
         return attn_metadata
```
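The per-layer routing above is a plain substring check on the layer name; a condensed sketch of the dispatch (illustrative names and stand-in objects, not the runner's actual state):

```python
# One metadata object per attention flavor, shared by every layer of
# that flavor in the KV-cache group.
attn_metadata_full_attention = object()  # stand-in for the built metadata
attn_metadata_gdn_attention = object()

layer_names = ["layers.0.self_attn", "layers.1.linear_attn"]
attn_metadata = {}
for layer_name in layer_names:
    if "linear_attn" in layer_name:
        attn_metadata[layer_name] = attn_metadata_gdn_attention
    else:
        attn_metadata[layer_name] = attn_metadata_full_attention
```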
Finally, the dummy-run call site passes the new argument through:

```diff
@@ -2895,6 +2929,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             max_query_len=max_query_len,
             aclgraph_runtime_mode=aclgraph_runtime_mode,
             force_attention=force_attention,
+            num_scheduled_tokens=num_scheduled_tokens,
         )
 
         need_dummy_logits = (not self.in_profile_run
```